diff --git a/.github/ISSUE_TEMPLATE/remote-vae-pilot-feedback.yml b/.github/ISSUE_TEMPLATE/remote-vae-pilot-feedback.yml
new file mode 100644
index 000000000000..c94d3bed9738
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/remote-vae-pilot-feedback.yml
@@ -0,0 +1,38 @@
+name: "\U0001F31F Remote VAE"
+description: Feedback for remote VAE pilot
+labels: [ "Remote VAE" ]
+
+body:
+ - type: textarea
+ id: positive
+ validations:
+ required: true
+ attributes:
+ label: Did you like the remote VAE solution?
+ description: |
+ If you liked it, we would appreciate it if you could elaborate what you liked.
+
+ - type: textarea
+ id: feedback
+ validations:
+ required: true
+ attributes:
+ label: What can be improved about the current solution?
+ description: |
+ Let us know the things you would like to see improved. Note that we will work optimizing the solution once the pilot is over and we have usage.
+
+ - type: textarea
+ id: others
+ validations:
+ required: true
+ attributes:
+ label: What other VAEs you would like to see if the pilot goes well?
+ description: |
+ Provide a list of the VAEs you would like to see in the future if the pilot goes well.
+
+ - type: textarea
+ id: additional-info
+ attributes:
+ label: Notify the members of the team
+ description: |
+ Tag the following folks when submitting this feedback: @hlky @sayakpaul
diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index d311c1c73f11..ff915e046946 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -38,6 +38,7 @@ jobs:
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install pandas peft
+ python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0
- name: Environment
run: |
python utils/print_env.py
diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml
index 9f4776db4315..340d8a19e17a 100644
--- a/.github/workflows/build_docker_images.yml
+++ b/.github/workflows/build_docker_images.yml
@@ -34,7 +34,7 @@ jobs:
id: file_changes
uses: jitterbit/get-changed-files@v1
with:
- format: 'space-delimited'
+ format: "space-delimited"
token: ${{ secrets.GITHUB_TOKEN }}
- name: Build Changed Docker Images
@@ -67,6 +67,7 @@ jobs:
- diffusers-pytorch-cuda
- diffusers-pytorch-compile-cuda
- diffusers-pytorch-xformers-cuda
+ - diffusers-pytorch-minimum-cuda
- diffusers-flax-cpu
- diffusers-flax-tpu
- diffusers-onnxruntime-cpu
diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index 142dbb0f1e8f..2b39eea2fe5d 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -180,14 +180,128 @@ jobs:
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
+ run_big_gpu_torch_tests:
+ name: Torch tests on big GPU
+ strategy:
+ fail-fast: false
+ max-parallel: 2
+ runs-on:
+ group: aws-g6e-xlarge-plus
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: NVIDIA-SMI
+ run: nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ python -m uv pip install peft@git+https://github.com/huggingface/peft.git
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ python -m uv pip install pytest-reportlog
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Selected Torch CUDA Test on big GPU
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ BIG_GPU_MEMORY: 40
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -m "big_gpu_with_torch_cuda" \
+ --make-reports=tests_big_gpu_torch_cuda \
+ --report-log=tests_big_gpu_torch_cuda.log \
+ tests/
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_big_gpu_torch_cuda_stats.txt
+ cat reports/tests_big_gpu_torch_cuda_failures_short.txt
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_cuda_big_gpu_test_reports
+ path: reports
+ - name: Generate Report and Notify Channel
+ if: always()
+ run: |
+ pip install slack_sdk tabulate
+ python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
+
+ torch_minimum_version_cuda_tests:
+ name: Torch Minimum Version CUDA Tests
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-pytorch-minimum-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ python -m uv pip install peft@git+https://github.com/huggingface/peft.git
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Run PyTorch CUDA tests
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_torch_minimum_version_cuda \
+ tests/models/test_modeling_common.py \
+ tests/pipelines/test_pipelines_common.py \
+ tests/pipelines/test_pipeline_utils.py \
+ tests/pipelines/test_pipelines.py \
+ tests/pipelines/test_pipelines_auto.py \
+ tests/schedulers/test_schedulers.py \
+ tests/others
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_torch_minimum_version_cuda_stats.txt
+ cat reports/tests_torch_minimum_version_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_minimum_version_cuda_test_reports
+ path: reports
+
run_flax_tpu_tests:
name: Nightly Flax TPU Tests
- runs-on: docker-tpu
+ runs-on:
+ group: gcp-ct5lp-hightpu-8t
if: github.event_name == 'schedule'
container:
image: diffusers/diffusers-flax-tpu
- options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --privileged
+ options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
defaults:
run:
shell: bash
@@ -291,6 +405,77 @@ jobs:
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
+ run_nightly_quantization_tests:
+ name: Torch quantization nightly tests
+ strategy:
+ fail-fast: false
+ max-parallel: 2
+ matrix:
+ config:
+ - backend: "bitsandbytes"
+ test_location: "bnb"
+ additional_deps: ["peft"]
+ - backend: "gguf"
+ test_location: "gguf"
+ additional_deps: []
+ - backend: "torchao"
+ test_location: "torchao"
+ additional_deps: []
+ - backend: "optimum_quanto"
+ test_location: "quanto"
+ additional_deps: []
+ runs-on:
+ group: aws-g6e-xlarge-plus
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "20gb" --ipc host --gpus 0
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: NVIDIA-SMI
+ run: nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ python -m uv pip install -U ${{ matrix.config.backend }}
+ if [ "${{ join(matrix.config.additional_deps, ' ') }}" != "" ]; then
+ python -m uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
+ fi
+ python -m uv pip install pytest-reportlog
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: ${{ matrix.config.backend }} quantization tests on GPU
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ BIG_GPU_MEMORY: 40
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ --make-reports=tests_${{ matrix.config.backend }}_torch_cuda \
+ --report-log=tests_${{ matrix.config.backend }}_torch_cuda.log \
+ tests/quantization/${{ matrix.config.test_location }}
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_${{ matrix.config.backend }}_torch_cuda_stats.txt
+ cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_cuda_${{ matrix.config.backend }}_reports
+ path: reports
+ - name: Generate Report and Notify Channel
+ if: always()
+ run: |
+ pip install slack_sdk tabulate
+ python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
+
# M1 runner currently not well supported
# TODO: (Dhruv) add these back when we setup better testing for Apple Silicon
# run_nightly_tests_apple_m1:
@@ -329,7 +514,7 @@ jobs:
# shell: arch -arch arm64 bash {0}
# env:
# HF_HOME: /System/Volumes/Data/mnt/cache
-# HF_TOKEN: ${{ secrets.HF_TOKEN }}
+# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
@@ -385,7 +570,7 @@ jobs:
# shell: arch -arch arm64 bash {0}
# env:
# HF_HOME: /System/Volumes/Data/mnt/cache
-# HF_TOKEN: ${{ secrets.HF_TOKEN }}
+# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
@@ -405,4 +590,4 @@ jobs:
# if: always()
# run: |
# pip install slack_sdk tabulate
-# python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
\ No newline at end of file
+# python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
diff --git a/.github/workflows/pr_style_bot.yml b/.github/workflows/pr_style_bot.yml
new file mode 100644
index 000000000000..cf2439c4f2c4
--- /dev/null
+++ b/.github/workflows/pr_style_bot.yml
@@ -0,0 +1,51 @@
+name: PR Style Bot
+
+on:
+ issue_comment:
+ types: [created]
+
+permissions:
+ contents: write
+ pull-requests: write
+
+jobs:
+ style:
+ uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
+ with:
+ python_quality_dependencies: "[quality]"
+ pre_commit_script_name: "Download and Compare files from the main branch"
+ pre_commit_script: |
+ echo "Downloading the files from the main branch"
+
+ curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
+ curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
+ curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
+
+ echo "Compare the files and raise error if needed"
+
+ diff_failed=0
+ if ! diff -q main_Makefile Makefile; then
+ echo "Error: The Makefile has changed. Please ensure it matches the main branch."
+ diff_failed=1
+ fi
+
+ if ! diff -q main_setup.py setup.py; then
+ echo "Error: The setup.py has changed. Please ensure it matches the main branch."
+ diff_failed=1
+ fi
+
+ if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
+ echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
+ diff_failed=1
+ fi
+
+ if [ $diff_failed -eq 1 ]; then
+ echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
+ exit 1
+ fi
+
+ echo "No changes in the files. Proceeding..."
+ rm -rf main_Makefile main_setup.py main_check_doc_toc.py
+ style_command: "make style && make quality"
+ secrets:
+ bot_token: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
diff --git a/.github/workflows/pr_test_peft_backend.yml b/.github/workflows/pr_test_peft_backend.yml
deleted file mode 100644
index 190e5d26e6f3..000000000000
--- a/.github/workflows/pr_test_peft_backend.yml
+++ /dev/null
@@ -1,134 +0,0 @@
-name: Fast tests for PRs - PEFT backend
-
-on:
- pull_request:
- branches:
- - main
- paths:
- - "src/diffusers/**.py"
- - "tests/**.py"
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
- cancel-in-progress: true
-
-env:
- DIFFUSERS_IS_CI: yes
- OMP_NUM_THREADS: 4
- MKL_NUM_THREADS: 4
- PYTEST_TIMEOUT: 60
-
-jobs:
- check_code_quality:
- runs-on: ubuntu-22.04
- steps:
- - uses: actions/checkout@v3
- - name: Set up Python
- uses: actions/setup-python@v4
- with:
- python-version: "3.8"
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install .[quality]
- - name: Check quality
- run: make quality
- - name: Check if failure
- if: ${{ failure() }}
- run: |
- echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
-
- check_repository_consistency:
- needs: check_code_quality
- runs-on: ubuntu-22.04
- steps:
- - uses: actions/checkout@v3
- - name: Set up Python
- uses: actions/setup-python@v4
- with:
- python-version: "3.8"
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install .[quality]
- - name: Check repo consistency
- run: |
- python utils/check_copies.py
- python utils/check_dummies.py
- make deps_table_check_updated
- - name: Check if failure
- if: ${{ failure() }}
- run: |
- echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
-
- run_fast_tests:
- needs: [check_code_quality, check_repository_consistency]
- strategy:
- fail-fast: false
- matrix:
- lib-versions: ["main", "latest"]
-
-
- name: LoRA - ${{ matrix.lib-versions }}
-
- runs-on:
- group: aws-general-8-plus
-
- container:
- image: diffusers/diffusers-pytorch-cpu
- options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
-
- defaults:
- run:
- shell: bash
-
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- # TODO (sayakpaul, DN6): revisit `--no-deps`
- if [ "${{ matrix.lib-versions }}" == "main" ]; then
- python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
- python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- else
- python -m uv pip install -U peft --no-deps
- python -m uv pip install -U transformers accelerate --no-deps
- fi
-
- - name: Environment
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python utils/print_env.py
-
- - name: Run fast PyTorch LoRA CPU tests with PEFT backend
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v \
- --make-reports=tests_${{ matrix.lib-versions }} \
- tests/lora/
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v \
- --make-reports=tests_models_lora_${{ matrix.lib-versions }} \
- tests/models/ -k "lora"
-
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_${{ matrix.lib-versions }}_failures_short.txt
- cat reports/tests_models_lora_${{ matrix.lib-versions }}_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: pr_${{ matrix.lib-versions }}_test_reports
- path: reports
diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index ec3e55a5e882..10d3cb3248d9 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -2,8 +2,7 @@ name: Fast tests for PRs
on:
pull_request:
- branches:
- - main
+ branches: [main]
paths:
- "src/diffusers/**.py"
- "benchmarks/**.py"
@@ -64,6 +63,7 @@ jobs:
run: |
python utils/check_copies.py
python utils/check_dummies.py
+ python utils/check_support_list.py
make deps_table_check_updated
- name: Check if failure
if: ${{ failure() }}
@@ -120,7 +120,8 @@ jobs:
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
- python -m uv pip install accelerate
+ pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
run: |
@@ -234,3 +235,68 @@ jobs:
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports
+
+ run_lora_tests:
+ needs: [check_code_quality, check_repository_consistency]
+ strategy:
+ fail-fast: false
+
+ name: LoRA tests with PEFT main
+
+ runs-on:
+ group: aws-general-8-plus
+
+ container:
+ image: diffusers/diffusers-pytorch-cpu
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+ defaults:
+ run:
+ shell: bash
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ # TODO (sayakpaul, DN6): revisit `--no-deps`
+ python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
+ python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+ python -m uv pip install -U tokenizers
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+
+ - name: Environment
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python utils/print_env.py
+
+ - name: Run fast PyTorch LoRA tests with PEFT
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ -s -v \
+ --make-reports=tests_peft_main \
+ tests/lora/
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ -s -v \
+ --make-reports=tests_models_lora_peft_main \
+ tests/models/ -k "lora"
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_lora_failures_short.txt
+ cat reports/tests_models_lora_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pr_main_test_reports
+ path: reports
+
diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml
new file mode 100644
index 000000000000..87d51773888e
--- /dev/null
+++ b/.github/workflows/pr_tests_gpu.yml
@@ -0,0 +1,296 @@
+name: Fast GPU Tests on PR
+
+on:
+ pull_request:
+ branches: main
+ paths:
+ - "src/diffusers/models/modeling_utils.py"
+ - "src/diffusers/models/model_loading_utils.py"
+ - "src/diffusers/pipelines/pipeline_utils.py"
+ - "src/diffusers/pipeline_loading_utils.py"
+ - "src/diffusers/loaders/lora_base.py"
+ - "src/diffusers/loaders/lora_pipeline.py"
+ - "src/diffusers/loaders/peft.py"
+ - "tests/pipelines/test_pipelines_common.py"
+ - "tests/models/test_modeling_common.py"
+ workflow_dispatch:
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+env:
+ DIFFUSERS_IS_CI: yes
+ OMP_NUM_THREADS: 8
+ MKL_NUM_THREADS: 8
+ HF_HUB_ENABLE_HF_TRANSFER: 1
+ PYTEST_TIMEOUT: 600
+ PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
+
+jobs:
+ check_code_quality:
+ runs-on: ubuntu-22.04
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install .[quality]
+ - name: Check quality
+ run: make quality
+ - name: Check if failure
+ if: ${{ failure() }}
+ run: |
+ echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
+
+ check_repository_consistency:
+ needs: check_code_quality
+ runs-on: ubuntu-22.04
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install .[quality]
+ - name: Check repo consistency
+ run: |
+ python utils/check_copies.py
+ python utils/check_dummies.py
+ python utils/check_support_list.py
+ make deps_table_check_updated
+ - name: Check if failure
+ if: ${{ failure() }}
+ run: |
+ echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
+
+ setup_torch_cuda_pipeline_matrix:
+ needs: [check_code_quality, check_repository_consistency]
+ name: Setup Torch Pipelines CUDA Slow Tests Matrix
+ runs-on:
+ group: aws-general-8-plus
+ container:
+ image: diffusers/diffusers-pytorch-cpu
+ outputs:
+ pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Fetch Pipeline Matrix
+ id: fetch_pipeline_matrix
+ run: |
+ matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)
+ echo $matrix
+ echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
+ - name: Pipeline Tests Artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: test-pipelines.json
+ path: reports
+
+ torch_pipelines_cuda_tests:
+ name: Torch Pipelines CUDA Tests
+ needs: setup_torch_cuda_pipeline_matrix
+ strategy:
+ fail-fast: false
+ max-parallel: 8
+ matrix:
+ module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Extract tests
+ id: extract_tests
+ run: |
+ pattern=$(python utils/extract_tests_from_mixin.py --type pipeline)
+ echo "$pattern" > /tmp/test_pattern.txt
+ echo "pattern_file=/tmp/test_pattern.txt" >> $GITHUB_OUTPUT
+
+ - name: PyTorch CUDA checkpoint tests on Ubuntu
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ if [ "${{ matrix.module }}" = "ip_adapters" ]; then
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
+ tests/pipelines/${{ matrix.module }}
+ else
+ pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx and $pattern" \
+ --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
+ tests/pipelines/${{ matrix.module }}
+ fi
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
+ cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pipeline_${{ matrix.module }}_test_reports
+ path: reports
+
+ torch_cuda_tests:
+ name: Torch CUDA Tests
+ needs: [check_code_quality, check_repository_consistency]
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ defaults:
+ run:
+ shell: bash
+ strategy:
+ fail-fast: false
+ max-parallel: 2
+ matrix:
+ module: [models, schedulers, lora, others]
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ python -m uv pip install peft@git+https://github.com/huggingface/peft.git
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Extract tests
+ id: extract_tests
+ run: |
+ pattern=$(python utils/extract_tests_from_mixin.py --type ${{ matrix.module }})
+ echo "$pattern" > /tmp/test_pattern.txt
+ echo "pattern_file=/tmp/test_pattern.txt" >> $GITHUB_OUTPUT
+
+ - name: Run PyTorch CUDA tests
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
+ if [ -z "$pattern" ]; then
+ python -m pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
+ --make-reports=tests_torch_cuda_${{ matrix.module }}
+ else
+ python -m pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
+ --make-reports=tests_torch_cuda_${{ matrix.module }}
+ fi
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_torch_cuda_${{ matrix.module }}_stats.txt
+ cat reports/tests_torch_cuda_${{ matrix.module }}_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_cuda_test_reports_${{ matrix.module }}
+ path: reports
+
+ run_examples_tests:
+ name: Examples PyTorch CUDA tests on Ubuntu
+ needs: [check_code_quality, check_repository_consistency]
+ runs-on:
+ group: aws-g4dn-2xlarge
+
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --gpus 0 --shm-size "16gb" --ipc host
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: NVIDIA-SMI
+ run: |
+ nvidia-smi
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+ python -m uv pip install -e [quality,test,training]
+
+ - name: Environment
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python utils/print_env.py
+
+ - name: Run example tests on GPU
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install timm
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/examples_torch_cuda_stats.txt
+ cat reports/examples_torch_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: examples_test_reports
+ path: reports
+
diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml
index f07e6cda0d59..abf825eaa7a0 100644
--- a/.github/workflows/push_tests.yml
+++ b/.github/workflows/push_tests.yml
@@ -81,9 +81,9 @@ jobs:
- name: Environment
run: |
python utils/print_env.py
- - name: Slow PyTorch CUDA checkpoint tests on Ubuntu
+ - name: PyTorch CUDA checkpoint tests on Ubuntu
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -137,7 +137,7 @@ jobs:
- name: Run PyTorch CUDA tests
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -161,10 +161,11 @@ jobs:
flax_tpu_tests:
name: Flax TPU Tests
- runs-on: docker-tpu
+ runs-on:
+ group: gcp-ct5lp-hightpu-8t
container:
image: diffusers/diffusers-flax-tpu
- options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
+ options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
defaults:
run:
shell: bash
@@ -184,9 +185,9 @@ jobs:
run: |
python utils/print_env.py
- - name: Run slow Flax TPU tests
+ - name: Run Flax TPU tests
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
@@ -232,9 +233,9 @@ jobs:
run: |
python utils/print_env.py
- - name: Run slow ONNXRuntime CUDA tests
+ - name: Run ONNXRuntime CUDA tests
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
@@ -282,7 +283,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
@@ -325,7 +326,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
@@ -348,7 +349,6 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
-
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -358,7 +358,6 @@ jobs:
- name: NVIDIA-SMI
run: |
nvidia-smi
-
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
@@ -371,7 +370,7 @@ jobs:
- name: Run example tests on GPU
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install timm
diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml
index 8d521074a08f..5fd3b78be7df 100644
--- a/.github/workflows/push_tests_mps.yml
+++ b/.github/workflows/push_tests_mps.yml
@@ -46,7 +46,7 @@ jobs:
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python -m pip install --upgrade pip uv
- ${CONDA_RUN} python -m uv pip install -e [quality,test]
+ ${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
${CONDA_RUN} python -m uv pip install transformers --upgrade
diff --git a/.github/workflows/pypi_publish.yaml b/.github/workflows/pypi_publish.yaml
index 33a5bb5640f2..dc36b6b024c5 100644
--- a/.github/workflows/pypi_publish.yaml
+++ b/.github/workflows/pypi_publish.yaml
@@ -68,7 +68,7 @@ jobs:
- name: Test installing diffusers and importing
run: |
pip install diffusers && pip uninstall diffusers -y
- pip install -i https://testpypi.python.org/pypi diffusers
+ pip install -i https://test.pypi.org/simple/ diffusers
python -c "from diffusers import __version__; print(__version__)"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
diff --git a/.github/workflows/release_tests_fast.yml b/.github/workflows/release_tests_fast.yml
index a8a6f2699dca..27bd9bd9bb42 100644
--- a/.github/workflows/release_tests_fast.yml
+++ b/.github/workflows/release_tests_fast.yml
@@ -81,7 +81,7 @@ jobs:
python utils/print_env.py
- name: Slow PyTorch CUDA checkpoint tests on Ubuntu
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -135,7 +135,7 @@ jobs:
- name: Run PyTorch CUDA tests
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -157,6 +157,63 @@ jobs:
name: torch_cuda_${{ matrix.module }}_test_reports
path: reports
+ torch_minimum_version_cuda_tests:
+ name: Torch Minimum Version CUDA Tests
+ runs-on:
+ group: aws-g4dn-2xlarge
+ container:
+ image: diffusers/diffusers-pytorch-minimum-cuda
+ options: --shm-size "16gb" --ipc host --gpus 0
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+ python -m uv pip install -e [quality,test]
+ python -m uv pip install peft@git+https://github.com/huggingface/peft.git
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Run PyTorch CUDA tests
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ run: |
+ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_torch_minimum_cuda \
+ tests/models/test_modeling_common.py \
+ tests/pipelines/test_pipelines_common.py \
+ tests/pipelines/test_pipeline_utils.py \
+ tests/pipelines/test_pipelines.py \
+ tests/pipelines/test_pipelines_auto.py \
+ tests/schedulers/test_schedulers.py \
+ tests/others
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_torch_minimum_version_cuda_stats.txt
+ cat reports/tests_torch_minimum_version_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_minimum_version_cuda_test_reports
+ path: reports
+
flax_tpu_tests:
name: Flax TPU Tests
runs-on: docker-tpu
@@ -184,7 +241,7 @@ jobs:
- name: Run slow Flax TPU tests
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
@@ -232,7 +289,7 @@ jobs:
- name: Run slow ONNXRuntime CUDA tests
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
@@ -280,7 +337,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
@@ -323,7 +380,7 @@ jobs:
python utils/print_env.py
- name: Run example tests on GPU
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
@@ -369,7 +426,7 @@ jobs:
- name: Run example tests on GPU
env:
- HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install timm
diff --git a/.github/workflows/run_tests_from_a_pr.yml b/.github/workflows/run_tests_from_a_pr.yml
index 1e736e543089..94fbb2d297c5 100644
--- a/.github/workflows/run_tests_from_a_pr.yml
+++ b/.github/workflows/run_tests_from_a_pr.yml
@@ -7,8 +7,8 @@ on:
default: 'diffusers/diffusers-pytorch-cuda'
description: 'Name of the Docker image'
required: true
- branch:
- description: 'PR Branch to test on'
+ pr_number:
+ description: 'PR number to test on'
required: true
test:
description: 'Tests to run (e.g.: `tests/models`).'
@@ -43,8 +43,8 @@ jobs:
exit 1
fi
- if [[ ! "$PY_TEST" =~ ^tests/(models|pipelines) ]]; then
- echo "Error: The input string must contain either 'models' or 'pipelines' after 'tests/'."
+ if [[ ! "$PY_TEST" =~ ^tests/(models|pipelines|lora) ]]; then
+ echo "Error: The input string must contain either 'models', 'pipelines', or 'lora' after 'tests/'."
exit 1
fi
@@ -53,13 +53,13 @@ jobs:
exit 1
fi
echo "$PY_TEST"
+
+ shell: bash -e {0}
- name: Checkout PR branch
uses: actions/checkout@v4
with:
- ref: ${{ github.event.inputs.branch }}
- repository: ${{ github.event.pull_request.head.repo.full_name }}
-
+ ref: refs/pull/${{ inputs.pr_number }}/head
- name: Install pytest
run: |
diff --git a/.github/workflows/ssh-runner.yml b/.github/workflows/ssh-runner.yml
index 0d4fe1578ba6..fd65598a53a7 100644
--- a/.github/workflows/ssh-runner.yml
+++ b/.github/workflows/ssh-runner.yml
@@ -4,12 +4,13 @@ on:
workflow_dispatch:
inputs:
runner_type:
- description: 'Type of runner to test (aws-g6-4xlarge-plus: a10 or aws-g4dn-2xlarge: t4)'
+ description: 'Type of runner to test (aws-g6-4xlarge-plus: a10, aws-g4dn-2xlarge: t4, aws-g6e-xlarge-plus: L40)'
type: choice
required: true
options:
- aws-g6-4xlarge-plus
- aws-g4dn-2xlarge
+ - aws-g6e-xlarge-plus
docker_image:
description: 'Name of the Docker image'
required: true
diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml
index 44f821ea84ed..4743dc352455 100644
--- a/.github/workflows/trufflehog.yml
+++ b/.github/workflows/trufflehog.yml
@@ -13,3 +13,6 @@ jobs:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main
+ with:
+ extra_args: --results=verified,unknown
+
diff --git a/README.md b/README.md
index b99ca828e4d0..dac3b3598aaf 100644
--- a/README.md
+++ b/README.md
@@ -112,9 +112,9 @@ Check out the [Quickstart](https://huggingface.co/docs/diffusers/quicktour) to l
| **Documentation** | **What can I learn?** |
|---------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [Tutorial](https://huggingface.co/docs/diffusers/tutorials/tutorial_overview) | A basic crash course for learning how to use the library's most important features like using models and schedulers to build your own diffusion system, and training your own diffusion model. |
-| [Loading](https://huggingface.co/docs/diffusers/using-diffusers/loading_overview) | Guides for how to load and configure all the components (pipelines, models, and schedulers) of the library, as well as how to use different schedulers. |
-| [Pipelines for inference](https://huggingface.co/docs/diffusers/using-diffusers/pipeline_overview) | Guides for how to use pipelines for different inference tasks, batched generation, controlling generated outputs and randomness, and how to contribute a pipeline to the library. |
-| [Optimization](https://huggingface.co/docs/diffusers/optimization/opt_overview) | Guides for how to optimize your diffusion model to run faster and consume less memory. |
+| [Loading](https://huggingface.co/docs/diffusers/using-diffusers/loading) | Guides for how to load and configure all the components (pipelines, models, and schedulers) of the library, as well as how to use different schedulers. |
+| [Pipelines for inference](https://huggingface.co/docs/diffusers/using-diffusers/overview_techniques) | Guides for how to use pipelines for different inference tasks, batched generation, controlling generated outputs and randomness, and how to contribute a pipeline to the library. |
+| [Optimization](https://huggingface.co/docs/diffusers/optimization/fp16) | Guides for how to optimize your diffusion model to run faster and consume less memory. |
| [Training](https://huggingface.co/docs/diffusers/training/overview) | Guides for how to train a diffusion model for different tasks with different training techniques. |
## Contribution
diff --git a/docker/diffusers-onnxruntime-cuda/Dockerfile b/docker/diffusers-onnxruntime-cuda/Dockerfile
index 6124172e109e..bd1d871033c9 100644
--- a/docker/diffusers-onnxruntime-cuda/Dockerfile
+++ b/docker/diffusers-onnxruntime-cuda/Dockerfile
@@ -28,7 +28,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.10 -m uv pip install --no-cache-dir \
- "torch<2.5.0" \
+ torch \
torchvision \
torchaudio \
"onnxruntime-gpu>=1.13.1" \
diff --git a/docker/diffusers-pytorch-compile-cuda/Dockerfile b/docker/diffusers-pytorch-compile-cuda/Dockerfile
index 9d7578f5a4dc..cb4a9c0f9896 100644
--- a/docker/diffusers-pytorch-compile-cuda/Dockerfile
+++ b/docker/diffusers-pytorch-compile-cuda/Dockerfile
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.10 -m uv pip install --no-cache-dir \
- "torch<2.5.0" \
+ torch \
torchvision \
torchaudio \
invisible_watermark && \
diff --git a/docker/diffusers-pytorch-cpu/Dockerfile b/docker/diffusers-pytorch-cpu/Dockerfile
index 1b39e58ca273..8d98c52598d2 100644
--- a/docker/diffusers-pytorch-cpu/Dockerfile
+++ b/docker/diffusers-pytorch-cpu/Dockerfile
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.10 -m uv pip install --no-cache-dir \
- "torch<2.5.0" \
+ torch \
torchvision \
torchaudio \
invisible_watermark \
diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile
index 7317ef642aa5..695f5ed08dc5 100644
--- a/docker/diffusers-pytorch-cuda/Dockerfile
+++ b/docker/diffusers-pytorch-cuda/Dockerfile
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.10 -m uv pip install --no-cache-dir \
- "torch<2.5.0" \
+ torch \
torchvision \
torchaudio \
invisible_watermark && \
diff --git a/docker/diffusers-pytorch-minimum-cuda/Dockerfile b/docker/diffusers-pytorch-minimum-cuda/Dockerfile
new file mode 100644
index 000000000000..57ca7657acf1
--- /dev/null
+++ b/docker/diffusers-pytorch-minimum-cuda/Dockerfile
@@ -0,0 +1,53 @@
+FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV MINIMUM_SUPPORTED_TORCH_VERSION="2.1.0"
+ENV MINIMUM_SUPPORTED_TORCHVISION_VERSION="0.16.0"
+ENV MINIMUM_SUPPORTED_TORCHAUDIO_VERSION="2.1.0"
+
+RUN apt-get -y update \
+ && apt-get install -y software-properties-common \
+ && add-apt-repository ppa:deadsnakes/ppa
+
+RUN apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ libgl1 \
+ python3.10 \
+ python3.10-dev \
+ python3-pip \
+ python3.10-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3.10 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+ python3.10 -m uv pip install --no-cache-dir \
+ torch==$MINIMUM_SUPPORTED_TORCH_VERSION \
+ torchvision==$MINIMUM_SUPPORTED_TORCHVISION_VERSION \
+ torchaudio==$MINIMUM_SUPPORTED_TORCHAUDIO_VERSION \
+ invisible_watermark && \
+ python3.10 -m pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ hf_transfer \
+ Jinja2 \
+ librosa \
+ numpy==1.26.4 \
+ scipy \
+ tensorboard \
+ transformers \
+ hf_transfer
+
+CMD ["/bin/bash"]
diff --git a/docker/diffusers-pytorch-xformers-cuda/Dockerfile b/docker/diffusers-pytorch-xformers-cuda/Dockerfile
index 356445a6d173..1693eb293024 100644
--- a/docker/diffusers-pytorch-xformers-cuda/Dockerfile
+++ b/docker/diffusers-pytorch-xformers-cuda/Dockerfile
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.10 -m pip install --no-cache-dir \
- "torch<2.5.0" \
+ torch \
torchvision \
torchaudio \
invisible_watermark && \
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 58218c0272bd..d39b5a52d2fe 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -48,13 +48,15 @@
- local: using-diffusers/inpaint
title: Inpainting
- local: using-diffusers/text-img2vid
- title: Text or image-to-video
+ title: Video generation
- local: using-diffusers/depth2img
title: Depth-to-image
title: Generative tasks
- sections:
- local: using-diffusers/overview_techniques
title: Overview
+ - local: using-diffusers/create_a_server
+ title: Create a server
- local: training/distributed_inference
title: Distributed inference
- local: using-diffusers/merge_loras
@@ -74,9 +76,21 @@
- local: advanced_inference/outpaint
title: Outpainting
title: Advanced inference
+- sections:
+ - local: hybrid_inference/overview
+ title: Overview
+ - local: hybrid_inference/vae_decode
+ title: VAE Decode
+ - local: hybrid_inference/vae_encode
+ title: VAE Encode
+ - local: hybrid_inference/api_reference
+ title: API Reference
+ title: Hybrid Inference
- sections:
- local: using-diffusers/cogvideox
title: CogVideoX
+ - local: using-diffusers/consisid
+ title: ConsisID
- local: using-diffusers/sdxl
title: Stable Diffusion XL
- local: using-diffusers/sdxl_turbo
@@ -85,6 +99,8 @@
title: Kandinsky
- local: using-diffusers/ip_adapter
title: IP-Adapter
+ - local: using-diffusers/omnigen
+ title: OmniGen
- local: using-diffusers/pag
title: PAG
- local: using-diffusers/controlnet
@@ -155,6 +171,12 @@
title: Getting Started
- local: quantization/bitsandbytes
title: bitsandbytes
+ - local: quantization/gguf
+ title: gguf
+ - local: quantization/torchao
+ title: torchao
+ - local: quantization/quanto
+ title: quanto
title: Quantization Methods
- sections:
- local: optimization/fp16
@@ -173,6 +195,8 @@
title: TGATE
- local: optimization/xdit
title: xDiT
+ - local: optimization/para_attn
+ title: ParaAttention
- sections:
- local: using-diffusers/stable_diffusion_jax_how_to
title: JAX/Flax
@@ -188,6 +212,8 @@
title: Metal Performance Shaders (MPS)
- local: optimization/habana
title: Habana Gaudi
+ - local: optimization/neuron
+ title: AWS Neuron
title: Optimized hardware
title: Accelerate inference and reduce memory
- sections:
@@ -230,6 +256,8 @@
title: Textual Inversion
- local: api/loaders/unet
title: UNet
+ - local: api/loaders/transformer_sd3
+ title: SD3Transformer2D
- local: api/loaders/peft
title: PEFT
title: Loaders
@@ -248,36 +276,60 @@
title: SD3ControlNetModel
- local: api/models/controlnet_sparsectrl
title: SparseControlNetModel
+ - local: api/models/controlnet_union
+ title: ControlNetUnionModel
title: ControlNets
- sections:
+ - local: api/models/allegro_transformer3d
+ title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
+ - local: api/models/consisid_transformer3d
+ title: ConsisIDTransformer3DModel
- local: api/models/cogview3plus_transformer2d
title: CogView3PlusTransformer2DModel
+ - local: api/models/cogview4_transformer2d
+ title: CogView4Transformer2DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
+ - local: api/models/easyanimate_transformer3d
+ title: EasyAnimateTransformer3DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
+ - local: api/models/hunyuan_video_transformer_3d
+ title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
+ - local: api/models/lumina2_transformer2d
+ title: Lumina2Transformer2DModel
+ - local: api/models/ltx_video_transformer3d
+ title: LTXVideoTransformer3DModel
+ - local: api/models/mochi_transformer3d
+ title: MochiTransformer3DModel
+ - local: api/models/omnigen_transformer
+ title: OmniGenTransformer2DModel
- local: api/models/pixart_transformer2d
title: PixArtTransformer2DModel
- local: api/models/prior_transformer
title: PriorTransformer
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
+ - local: api/models/sana_transformer2d
+ title: SanaTransformer2DModel
- local: api/models/stable_audio_transformer
title: StableAudioDiTModel
- local: api/models/transformer2d
title: Transformer2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
+ - local: api/models/wan_transformer_3d
+ title: WanTransformer3DModel
title: Transformers
- sections:
- local: api/models/stable_cascade_unet
@@ -298,10 +350,24 @@
- sections:
- local: api/models/autoencoderkl
title: AutoencoderKL
+ - local: api/models/autoencoderkl_allegro
+ title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
+ - local: api/models/autoencoder_kl_hunyuan_video
+ title: AutoencoderKLHunyuanVideo
+ - local: api/models/autoencoderkl_ltx_video
+ title: AutoencoderKLLTXVideo
+ - local: api/models/autoencoderkl_magvit
+ title: AutoencoderKLMagvit
+ - local: api/models/autoencoderkl_mochi
+ title: AutoencoderKLMochi
+ - local: api/models/autoencoder_kl_wan
+ title: AutoencoderKLWan
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
+ - local: api/models/autoencoder_dc
+ title: AutoencoderDC
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/autoencoder_oobleck
@@ -316,6 +382,8 @@
sections:
- local: api/pipelines/overview
title: Overview
+ - local: api/pipelines/allegro
+ title: Allegro
- local: api/pipelines/amused
title: aMUSEd
- local: api/pipelines/animatediff
@@ -336,6 +404,10 @@
title: CogVideoX
- local: api/pipelines/cogview3
title: CogView3
+ - local: api/pipelines/cogview4
+ title: CogView4
+ - local: api/pipelines/consisid
+ title: ConsisID
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
@@ -352,6 +424,8 @@
title: ControlNet-XS
- local: api/pipelines/controlnetxs_sdxl
title: ControlNet-XS with Stable Diffusion XL
+ - local: api/pipelines/controlnet_union
+ title: ControlNetUnion
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/ddim
@@ -364,10 +438,16 @@
title: DiffEdit
- local: api/pipelines/dit
title: DiT
+ - local: api/pipelines/easyanimate
+ title: EasyAnimate
- local: api/pipelines/flux
title: Flux
+ - local: api/pipelines/control_flux_inpaint
+ title: FluxControlInpaint
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
+ - local: api/pipelines/hunyuan_video
+ title: HunyuanVideo
- local: api/pipelines/i2vgenxl
title: I2VGen-XL
- local: api/pipelines/pix2pix
@@ -388,14 +468,22 @@
title: Latte
- local: api/pipelines/ledits_pp
title: LEDITS++
+ - local: api/pipelines/ltx_video
+ title: LTXVideo
+ - local: api/pipelines/lumina2
+ title: Lumina 2.0
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
title: Marigold
+ - local: api/pipelines/mochi
+ title: Mochi
- local: api/pipelines/panorama
title: MultiDiffusion
- local: api/pipelines/musicldm
title: MusicLDM
+ - local: api/pipelines/omnigen
+ title: OmniGen
- local: api/pipelines/pag
title: PAG
- local: api/pipelines/paint_by_example
@@ -406,6 +494,10 @@
title: PixArt-α
- local: api/pipelines/pixart_sigma
title: PixArt-Σ
+ - local: api/pipelines/sana
+ title: Sana
+ - local: api/pipelines/sana_sprint
+ title: Sana Sprint
- local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion
@@ -466,6 +558,8 @@
title: UniDiffuser
- local: api/pipelines/value_guided_sampling
title: Value-guided sampling
+ - local: api/pipelines/wan
+ title: Wan
- local: api/pipelines/wuerstchen
title: Wuerstchen
title: Pipelines
@@ -475,6 +569,10 @@
title: Overview
- local: api/schedulers/cm_stochastic_iterative
title: CMStochasticIterativeScheduler
+ - local: api/schedulers/ddim_cogvideox
+ title: CogVideoXDDIMScheduler
+ - local: api/schedulers/multistep_dpm_solver_cogvideox
+ title: CogVideoXDPMScheduler
- local: api/schedulers/consistency_decoder
title: ConsistencyDecoderScheduler
- local: api/schedulers/cosine_dpm
@@ -544,6 +642,8 @@
title: Attention Processor
- local: api/activations
title: Custom activation functions
+ - local: api/cache
+ title: Caching methods
- local: api/normalization
title: Custom normalization layers
- local: api/utilities
diff --git a/docs/source/en/api/activations.md b/docs/source/en/api/activations.md
index 3bef28a5ab0d..140a2ae1a1b2 100644
--- a/docs/source/en/api/activations.md
+++ b/docs/source/en/api/activations.md
@@ -25,3 +25,16 @@ Customized activation functions for supporting various models in 🤗 Diffusers.
## ApproximateGELU
[[autodoc]] models.activations.ApproximateGELU
+
+
+## SwiGLU
+
+[[autodoc]] models.activations.SwiGLU
+
+## FP32SiLU
+
+[[autodoc]] models.activations.FP32SiLU
+
+## LinearActivation
+
+[[autodoc]] models.activations.LinearActivation
diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md
index 5b1f0be72ae6..638ecb973e5d 100644
--- a/docs/source/en/api/attnprocessor.md
+++ b/docs/source/en/api/attnprocessor.md
@@ -15,40 +15,152 @@ specific language governing permissions and limitations under the License.
An attention processor is a class for applying different types of attention mechanisms.
## AttnProcessor
+
[[autodoc]] models.attention_processor.AttnProcessor
-## AttnProcessor2_0
[[autodoc]] models.attention_processor.AttnProcessor2_0
-## AttnAddedKVProcessor
[[autodoc]] models.attention_processor.AttnAddedKVProcessor
-## AttnAddedKVProcessor2_0
[[autodoc]] models.attention_processor.AttnAddedKVProcessor2_0
+[[autodoc]] models.attention_processor.AttnProcessorNPU
+
+[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
+
+## Allegro
+
+[[autodoc]] models.attention_processor.AllegroAttnProcessor2_0
+
+## AuraFlow
+
+[[autodoc]] models.attention_processor.AuraFlowAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FusedAuraFlowAttnProcessor2_0
+
+## CogVideoX
+
+[[autodoc]] models.attention_processor.CogVideoXAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FusedCogVideoXAttnProcessor2_0
+
## CrossFrameAttnProcessor
+
[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor
-## CustomDiffusionAttnProcessor
+## Custom Diffusion
+
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
-## CustomDiffusionAttnProcessor2_0
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
-## CustomDiffusionXFormersAttnProcessor
[[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor
-## FusedAttnProcessor2_0
-[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
+## Flux
+
+[[autodoc]] models.attention_processor.FluxAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FusedFluxAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FluxSingleAttnProcessor2_0
+
+## Hunyuan
+
+[[autodoc]] models.attention_processor.HunyuanAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FusedHunyuanAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGHunyuanAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGCFGHunyuanAttnProcessor2_0
+
+## IdentitySelfAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGIdentitySelfAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0
+
+## IP-Adapter
+
+[[autodoc]] models.attention_processor.IPAdapterAttnProcessor
+
+[[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0
+
+## JointAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.JointAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGJointAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGCFGJointAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FusedJointAttnProcessor2_0
+
+## LoRA
+
+[[autodoc]] models.attention_processor.LoRAAttnProcessor
+
+[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor
+
+[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor
+
+## Lumina-T2X
+
+[[autodoc]] models.attention_processor.LuminaAttnProcessor2_0
+
+## Mochi
+
+[[autodoc]] models.attention_processor.MochiAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.MochiVaeAttnProcessor2_0
+
+## Sana
+
+[[autodoc]] models.attention_processor.SanaLinearAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.SanaMultiscaleAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0
+
+## Stable Audio
+
+[[autodoc]] models.attention_processor.StableAudioAttnProcessor2_0
## SlicedAttnProcessor
+
[[autodoc]] models.attention_processor.SlicedAttnProcessor
-## SlicedAttnAddedKVProcessor
[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
## XFormersAttnProcessor
+
[[autodoc]] models.attention_processor.XFormersAttnProcessor
-## AttnProcessorNPU
-[[autodoc]] models.attention_processor.AttnProcessorNPU
+[[autodoc]] models.attention_processor.XFormersAttnAddedKVProcessor
+
+## XLAFlashAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0
+
+## XFormersJointAttnProcessor
+
+[[autodoc]] models.attention_processor.XFormersJointAttnProcessor
+
+## IPAdapterXFormersAttnProcessor
+
+[[autodoc]] models.attention_processor.IPAdapterXFormersAttnProcessor
+
+## FluxIPAdapterJointAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FluxIPAdapterJointAttnProcessor2_0
+
+
+## XLAFluxFlashAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.XLAFluxFlashAttnProcessor2_0
\ No newline at end of file
diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md
new file mode 100644
index 000000000000..a6aa5445a845
--- /dev/null
+++ b/docs/source/en/api/cache.md
@@ -0,0 +1,82 @@
+
+
+# Caching methods
+
+## Pyramid Attention Broadcast
+
+[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
+
+Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
+
+Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
+
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
+# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
+# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
+# poorer quality of generated videos.
+config = PyramidAttentionBroadcastConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(100, 800),
+ current_timestep_callback=lambda: pipe.current_timestep,
+)
+pipe.transformer.enable_cache(config)
+```
+
+## Faster Cache
+
+[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
+
+FasterCache is a method that speeds up inference in diffusion transformers by:
+- Reusing attention states between successive inference steps, due to high similarity between them
+- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, FasterCacheConfig
+
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 681),
+ current_timestep_callback=lambda: pipe.current_timestep,
+ attention_weight_callback=lambda _: 0.3,
+ unconditional_batch_skip_range=5,
+ unconditional_batch_timestep_skip_range=(-1, 781),
+ tensor_format="BFCHW",
+)
+pipe.transformer.enable_cache(config)
+```
+
+### CacheMixin
+
+[[autodoc]] CacheMixin
+
+### PyramidAttentionBroadcastConfig
+
+[[autodoc]] PyramidAttentionBroadcastConfig
+
+[[autodoc]] apply_pyramid_attention_broadcast
+
+### FasterCacheConfig
+
+[[autodoc]] FasterCacheConfig
+
+[[autodoc]] apply_faster_cache
diff --git a/docs/source/en/api/loaders/ip_adapter.md b/docs/source/en/api/loaders/ip_adapter.md
index a10f30ef8e5b..946a8b1af875 100644
--- a/docs/source/en/api/loaders/ip_adapter.md
+++ b/docs/source/en/api/loaders/ip_adapter.md
@@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
+## SD3IPAdapterMixin
+
+[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
+ - all
+ - is_ip_adapter_active
+
## IPAdapterMaskProcessor
[[autodoc]] image_processor.IPAdapterMaskProcessor
\ No newline at end of file
diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md
index 2060a1eefd52..58611a61c25d 100644
--- a/docs/source/en/api/loaders/lora.md
+++ b/docs/source/en/api/loaders/lora.md
@@ -17,6 +17,13 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`StableDiffusionLoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
- [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`StableDiffusionLoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
+- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
+- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
+- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
+- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
+- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
+- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
+- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
@@ -38,6 +45,34 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin
+## FluxLoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
+
+## CogVideoXLoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
+
+## Mochi1LoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
+
+## LTXVideoLoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.LTXVideoLoraLoaderMixin
+
+## SanaLoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin
+
+## HunyuanVideoLoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin
+
+## Lumina2LoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin
+
## AmusedLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
diff --git a/docs/source/en/api/loaders/transformer_sd3.md b/docs/source/en/api/loaders/transformer_sd3.md
new file mode 100644
index 000000000000..4fc9603054b4
--- /dev/null
+++ b/docs/source/en/api/loaders/transformer_sd3.md
@@ -0,0 +1,29 @@
+
+
+# SD3Transformer2D
+
+This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
+
+The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
+
+
+
+To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
+
+
+
+## SD3Transformer2DLoadersMixin
+
+[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
+ - all
+ - _load_ip_adapter_weights
\ No newline at end of file
diff --git a/docs/source/en/api/models/allegro_transformer3d.md b/docs/source/en/api/models/allegro_transformer3d.md
new file mode 100644
index 000000000000..7b035cd05535
--- /dev/null
+++ b/docs/source/en/api/models/allegro_transformer3d.md
@@ -0,0 +1,30 @@
+
+
+# AllegroTransformer3DModel
+
+A Diffusion Transformer model for 3D data from [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AllegroTransformer3DModel
+
+transformer = AllegroTransformer3DModel.from_pretrained("rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## AllegroTransformer3DModel
+
+[[autodoc]] AllegroTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/autoencoder_dc.md b/docs/source/en/api/models/autoencoder_dc.md
new file mode 100644
index 000000000000..6f86150eb744
--- /dev/null
+++ b/docs/source/en/api/models/autoencoder_dc.md
@@ -0,0 +1,72 @@
+
+
+# AutoencoderDC
+
+The 2D Autoencoder model used in [SANA](https://huggingface.co/papers/2410.10629) and introduced in [DCAE](https://huggingface.co/papers/2410.10733) by authors Junyu Chen\*, Han Cai\*, Junsong Chen, Enze Xie, Shang Yang, Haotian Tang, Muyang Li, Yao Lu, Song Han from MIT HAN Lab.
+
+The abstract from the paper is:
+
+*We present Deep Compression Autoencoder (DC-AE), a new family of autoencoder models for accelerating high-resolution diffusion models. Existing autoencoder models have demonstrated impressive results at a moderate spatial compression ratio (e.g., 8x), but fail to maintain satisfactory reconstruction accuracy for high spatial compression ratios (e.g., 64x). We address this challenge by introducing two key techniques: (1) Residual Autoencoding, where we design our models to learn residuals based on the space-to-channel transformed features to alleviate the optimization difficulty of high spatial-compression autoencoders; (2) Decoupled High-Resolution Adaptation, an efficient decoupled three-phases training strategy for mitigating the generalization penalty of high spatial-compression autoencoders. With these designs, we improve the autoencoder's spatial compression ratio up to 128 while maintaining the reconstruction quality. Applying our DC-AE to latent diffusion models, we achieve significant speedup without accuracy drop. For example, on ImageNet 512x512, our DC-AE provides 19.1x inference speedup and 17.9x training speedup on H100 GPU for UViT-H while achieving a better FID, compared with the widely used SD-VAE-f8 autoencoder. Our code is available at [this https URL](https://github.com/mit-han-lab/efficientvit).*
+
+The following DCAE models are released and supported in Diffusers.
+
+| Diffusers format | Original format |
+|:----------------:|:---------------:|
+| [`mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-sana-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0)
+| [`mit-han-lab/dc-ae-f32c32-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0)
+| [`mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0)
+| [`mit-han-lab/dc-ae-f64c128-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0)
+| [`mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0)
+| [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0)
+| [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0)
+
+This model was contributed by [lawrence-cj](https://github.com/lawrence-cj).
+
+Load a model in Diffusers format with [`~ModelMixin.from_pretrained`].
+
+```python
+from diffusers import AutoencoderDC
+
+ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
+```
+
+## Load a model in Diffusers via `from_single_file`
+
+```python
+from difusers import AutoencoderDC
+
+ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
+model = AutoencoderDC.from_single_file(ckpt_path)
+
+```
+
+The `AutoencoderDC` model has `in` and `mix` single file checkpoint variants that have matching checkpoint keys, but use different scaling factors. It is not possible for Diffusers to automatically infer the correct config file to use with the model based on just the checkpoint and will default to configuring the model using the `mix` variant config file. To override the automatically determined config, please use the `config` argument when using single file loading with `in` variant checkpoints.
+
+```python
+from diffusers import AutoencoderDC
+
+ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0/blob/main/model.safetensors"
+model = AutoencoderDC.from_single_file(ckpt_path, config="mit-han-lab/dc-ae-f128c512-in-1.0-diffusers")
+```
+
+
+## AutoencoderDC
+
+[[autodoc]] AutoencoderDC
+ - encode
+ - decode
+ - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
+
diff --git a/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md
new file mode 100644
index 000000000000..33dff5b903cd
--- /dev/null
+++ b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md
@@ -0,0 +1,32 @@
+
+
+# AutoencoderKLHunyuanVideo
+
+The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLHunyuanVideo
+
+vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="vae", torch_dtype=torch.float16)
+```
+
+## AutoencoderKLHunyuanVideo
+
+[[autodoc]] AutoencoderKLHunyuanVideo
+ - decode
+ - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/autoencoder_kl_wan.md b/docs/source/en/api/models/autoencoder_kl_wan.md
new file mode 100644
index 000000000000..43165c8edf7a
--- /dev/null
+++ b/docs/source/en/api/models/autoencoder_kl_wan.md
@@ -0,0 +1,32 @@
+
+
+# AutoencoderKLWan
+
+The 3D variational autoencoder (VAE) model with KL loss used in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLWan
+
+vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
+```
+
+## AutoencoderKLWan
+
+[[autodoc]] AutoencoderKLWan
+ - decode
+ - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/autoencoderkl_allegro.md b/docs/source/en/api/models/autoencoderkl_allegro.md
new file mode 100644
index 000000000000..fd9d10d5724b
--- /dev/null
+++ b/docs/source/en/api/models/autoencoderkl_allegro.md
@@ -0,0 +1,37 @@
+
+
+# AutoencoderKLAllegro
+
+The 3D variational autoencoder (VAE) model with KL loss used in [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLAllegro
+
+vae = AutoencoderKLCogVideoX.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
+```
+
+## AutoencoderKLAllegro
+
+[[autodoc]] AutoencoderKLAllegro
+ - decode
+ - encode
+ - all
+
+## AutoencoderKLOutput
+
+[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/autoencoderkl_ltx_video.md b/docs/source/en/api/models/autoencoderkl_ltx_video.md
new file mode 100644
index 000000000000..fbdb11e29cdd
--- /dev/null
+++ b/docs/source/en/api/models/autoencoderkl_ltx_video.md
@@ -0,0 +1,37 @@
+
+
+# AutoencoderKLLTXVideo
+
+The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLLTXVideo
+
+vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
+```
+
+## AutoencoderKLLTXVideo
+
+[[autodoc]] AutoencoderKLLTXVideo
+ - decode
+ - encode
+ - all
+
+## AutoencoderKLOutput
+
+[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/autoencoderkl_magvit.md b/docs/source/en/api/models/autoencoderkl_magvit.md
new file mode 100644
index 000000000000..7c1060ddd435
--- /dev/null
+++ b/docs/source/en/api/models/autoencoderkl_magvit.md
@@ -0,0 +1,37 @@
+
+
+# AutoencoderKLMagvit
+
+The 3D variational autoencoder (VAE) model with KL loss used in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLMagvit
+
+vae = AutoencoderKLMagvit.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", torch_dtype=torch.float16).to("cuda")
+```
+
+## AutoencoderKLMagvit
+
+[[autodoc]] AutoencoderKLMagvit
+ - decode
+ - encode
+ - all
+
+## AutoencoderKLOutput
+
+[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/autoencoderkl_mochi.md b/docs/source/en/api/models/autoencoderkl_mochi.md
new file mode 100644
index 000000000000..9747de4af937
--- /dev/null
+++ b/docs/source/en/api/models/autoencoderkl_mochi.md
@@ -0,0 +1,32 @@
+
+
+# AutoencoderKLMochi
+
+The 3D variational autoencoder (VAE) model with KL loss used in [Mochi](https://github.com/genmoai/models) was introduced in [Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Tsinghua University & ZhipuAI.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLMochi
+
+vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32).to("cuda")
+```
+
+## AutoencoderKLMochi
+
+[[autodoc]] AutoencoderKLMochi
+ - decode
+ - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/cogvideox_transformer3d.md b/docs/source/en/api/models/cogvideox_transformer3d.md
index 8c8baae7b537..30556ef7be3f 100644
--- a/docs/source/en/api/models/cogvideox_transformer3d.md
+++ b/docs/source/en/api/models/cogvideox_transformer3d.md
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import CogVideoXTransformer3DModel
-vae = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
+transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
```
## CogVideoXTransformer3DModel
diff --git a/docs/source/en/api/models/cogview3plus_transformer2d.md b/docs/source/en/api/models/cogview3plus_transformer2d.md
index 16f71a58cfb4..7d022da79314 100644
--- a/docs/source/en/api/models/cogview3plus_transformer2d.md
+++ b/docs/source/en/api/models/cogview3plus_transformer2d.md
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import CogView3PlusTransformer2DModel
-vae = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+transformer = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## CogView3PlusTransformer2DModel
diff --git a/docs/source/en/api/models/cogview4_transformer2d.md b/docs/source/en/api/models/cogview4_transformer2d.md
new file mode 100644
index 000000000000..4bf14bdd4991
--- /dev/null
+++ b/docs/source/en/api/models/cogview4_transformer2d.md
@@ -0,0 +1,30 @@
+
+
+# CogView4Transformer2DModel
+
+A Diffusion Transformer model for 2D data from [CogView4]()
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import CogView4Transformer2DModel
+
+transformer = CogView4Transformer2DModel.from_pretrained("THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## CogView4Transformer2DModel
+
+[[autodoc]] CogView4Transformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/consisid_transformer3d.md b/docs/source/en/api/models/consisid_transformer3d.md
new file mode 100644
index 000000000000..bca03c099b1d
--- /dev/null
+++ b/docs/source/en/api/models/consisid_transformer3d.md
@@ -0,0 +1,30 @@
+
+
+# ConsisIDTransformer3DModel
+
+A Diffusion Transformer model for 3D data from [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) was introduced in [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://arxiv.org/pdf/2411.17440) by Peking University & University of Rochester & etc.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import ConsisIDTransformer3DModel
+
+transformer = ConsisIDTransformer3DModel.from_pretrained("BestWishYsh/ConsisID-preview", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## ConsisIDTransformer3DModel
+
+[[autodoc]] ConsisIDTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/controlnet.md b/docs/source/en/api/models/controlnet.md
index 966a0e53b496..5d4cac6658cc 100644
--- a/docs/source/en/api/models/controlnet.md
+++ b/docs/source/en/api/models/controlnet.md
@@ -39,7 +39,7 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
## ControlNetOutput
-[[autodoc]] models.controlnet.ControlNetOutput
+[[autodoc]] models.controlnets.controlnet.ControlNetOutput
## FlaxControlNetModel
@@ -47,4 +47,4 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
## FlaxControlNetOutput
-[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
+[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput
diff --git a/docs/source/en/api/models/controlnet_sd3.md b/docs/source/en/api/models/controlnet_sd3.md
index 59db64546fa2..78564d238eea 100644
--- a/docs/source/en/api/models/controlnet_sd3.md
+++ b/docs/source/en/api/models/controlnet_sd3.md
@@ -38,5 +38,5 @@ pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-di
## SD3ControlNetOutput
-[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput
+[[autodoc]] models.controlnets.controlnet_sd3.SD3ControlNetOutput
diff --git a/docs/source/en/api/models/controlnet_union.md b/docs/source/en/api/models/controlnet_union.md
new file mode 100644
index 000000000000..9c0d86984549
--- /dev/null
+++ b/docs/source/en/api/models/controlnet_union.md
@@ -0,0 +1,35 @@
+
+
+# ControlNetUnionModel
+
+ControlNetUnionModel is an implementation of ControlNet for Stable Diffusion XL.
+
+The ControlNet model was introduced in [ControlNetPlus](https://github.com/xinsir6/ControlNetPlus) by xinsir6. It supports multiple conditioning inputs without increasing computation.
+
+*We design a new architecture that can support 10+ control types in condition text-to-image generation and can generate high resolution images visually comparable with midjourney. The network is based on the original ControlNet architecture, we propose two new modules to: 1 Extend the original ControlNet to support different image conditions using the same network parameter. 2 Support multiple conditions input without increasing computation offload, which is especially important for designers who want to edit image in detail, different conditions use the same condition encoder, without adding extra computations or parameters.*
+
+## Loading
+
+By default the [`ControlNetUnionModel`] should be loaded with [`~ModelMixin.from_pretrained`].
+
+```py
+from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel
+
+controlnet = ControlNetUnionModel.from_pretrained("xinsir/controlnet-union-sdxl-1.0")
+pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet)
+```
+
+## ControlNetUnionModel
+
+[[autodoc]] ControlNetUnionModel
+
diff --git a/docs/source/en/api/models/easyanimate_transformer3d.md b/docs/source/en/api/models/easyanimate_transformer3d.md
new file mode 100644
index 000000000000..66670eb632d4
--- /dev/null
+++ b/docs/source/en/api/models/easyanimate_transformer3d.md
@@ -0,0 +1,30 @@
+
+
+# EasyAnimateTransformer3DModel
+
+A Diffusion Transformer model for 3D data from [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import EasyAnimateTransformer3DModel
+
+transformer = EasyAnimateTransformer3DModel.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
+```
+
+## EasyAnimateTransformer3DModel
+
+[[autodoc]] EasyAnimateTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/hunyuan_video_transformer_3d.md b/docs/source/en/api/models/hunyuan_video_transformer_3d.md
new file mode 100644
index 000000000000..522d0eb0479d
--- /dev/null
+++ b/docs/source/en/api/models/hunyuan_video_transformer_3d.md
@@ -0,0 +1,30 @@
+
+
+# HunyuanVideoTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import HunyuanVideoTransformer3DModel
+
+transformer = HunyuanVideoTransformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## HunyuanVideoTransformer3DModel
+
+[[autodoc]] HunyuanVideoTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/ltx_video_transformer3d.md b/docs/source/en/api/models/ltx_video_transformer3d.md
new file mode 100644
index 000000000000..fe2664cf685c
--- /dev/null
+++ b/docs/source/en/api/models/ltx_video_transformer3d.md
@@ -0,0 +1,30 @@
+
+
+# LTXVideoTransformer3DModel
+
+A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import LTXVideoTransformer3DModel
+
+transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## LTXVideoTransformer3DModel
+
+[[autodoc]] LTXVideoTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/lumina2_transformer2d.md b/docs/source/en/api/models/lumina2_transformer2d.md
new file mode 100644
index 000000000000..0d7c0585dcd5
--- /dev/null
+++ b/docs/source/en/api/models/lumina2_transformer2d.md
@@ -0,0 +1,30 @@
+
+
+# Lumina2Transformer2DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [Lumina Image 2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) by Alpha-VLLM.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import Lumina2Transformer2DModel
+
+transformer = Lumina2Transformer2DModel.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## Lumina2Transformer2DModel
+
+[[autodoc]] Lumina2Transformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/mochi_transformer3d.md b/docs/source/en/api/models/mochi_transformer3d.md
new file mode 100644
index 000000000000..6c8e464feded
--- /dev/null
+++ b/docs/source/en/api/models/mochi_transformer3d.md
@@ -0,0 +1,30 @@
+
+
+# MochiTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [Mochi-1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import MochiTransformer3DModel
+
+transformer = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
+```
+
+## MochiTransformer3DModel
+
+[[autodoc]] MochiTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/omnigen_transformer.md b/docs/source/en/api/models/omnigen_transformer.md
new file mode 100644
index 000000000000..78d29fdab5e4
--- /dev/null
+++ b/docs/source/en/api/models/omnigen_transformer.md
@@ -0,0 +1,30 @@
+
+
+# OmniGenTransformer2DModel
+
+A Transformer model that accepts multimodal instructions to generate images for [OmniGen](https://github.com/VectorSpaceLab/OmniGen/).
+
+The abstract from the paper is:
+
+*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*
+
+```python
+import torch
+from diffusers import OmniGenTransformer2DModel
+
+transformer = OmniGenTransformer2DModel.from_pretrained("Shitao/OmniGen-v1-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## OmniGenTransformer2DModel
+
+[[autodoc]] OmniGenTransformer2DModel
diff --git a/docs/source/en/api/models/sana_transformer2d.md b/docs/source/en/api/models/sana_transformer2d.md
new file mode 100644
index 000000000000..269aefd7ff69
--- /dev/null
+++ b/docs/source/en/api/models/sana_transformer2d.md
@@ -0,0 +1,34 @@
+
+
+# SanaTransformer2DModel
+
+A Diffusion Transformer model for 2D data from [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) was introduced from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
+
+The abstract from the paper is:
+
+*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import SanaTransformer2DModel
+
+transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## SanaTransformer2DModel
+
+[[autodoc]] SanaTransformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/wan_transformer_3d.md b/docs/source/en/api/models/wan_transformer_3d.md
new file mode 100644
index 000000000000..56015c4c07f1
--- /dev/null
+++ b/docs/source/en/api/models/wan_transformer_3d.md
@@ -0,0 +1,30 @@
+
+
+# WanTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import WanTransformer3DModel
+
+transformer = WanTransformer3DModel.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## WanTransformer3DModel
+
+[[autodoc]] WanTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/normalization.md b/docs/source/en/api/normalization.md
index ef4b694a4d85..05ae92a28dc8 100644
--- a/docs/source/en/api/normalization.md
+++ b/docs/source/en/api/normalization.md
@@ -29,3 +29,43 @@ Customized normalization layers for supporting various models in 🤗 Diffusers.
## AdaGroupNorm
[[autodoc]] models.normalization.AdaGroupNorm
+
+## AdaLayerNormContinuous
+
+[[autodoc]] models.normalization.AdaLayerNormContinuous
+
+## RMSNorm
+
+[[autodoc]] models.normalization.RMSNorm
+
+## GlobalResponseNorm
+
+[[autodoc]] models.normalization.GlobalResponseNorm
+
+
+## LuminaLayerNormContinuous
+[[autodoc]] models.normalization.LuminaLayerNormContinuous
+
+## SD35AdaLayerNormZeroX
+[[autodoc]] models.normalization.SD35AdaLayerNormZeroX
+
+## AdaLayerNormZeroSingle
+[[autodoc]] models.normalization.AdaLayerNormZeroSingle
+
+## LuminaRMSNormZero
+[[autodoc]] models.normalization.LuminaRMSNormZero
+
+## LpNorm
+[[autodoc]] models.normalization.LpNorm
+
+## CogView3PlusAdaLayerNormZeroTextImage
+[[autodoc]] models.normalization.CogView3PlusAdaLayerNormZeroTextImage
+
+## CogVideoXLayerNormZero
+[[autodoc]] models.normalization.CogVideoXLayerNormZero
+
+## MochiRMSNormZero
+[[autodoc]] models.transformers.transformer_mochi.MochiRMSNormZero
+
+## MochiRMSNorm
+[[autodoc]] models.normalization.MochiRMSNorm
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/allegro.md b/docs/source/en/api/pipelines/allegro.md
new file mode 100644
index 000000000000..690f8096a0e4
--- /dev/null
+++ b/docs/source/en/api/pipelines/allegro.md
@@ -0,0 +1,79 @@
+
+
+# Allegro
+
+[Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, Huan Yang.
+
+The abstract from the paper is:
+
+*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`AllegroPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, AllegroTransformer3DModel, AllegroPipeline
+from diffusers.utils import export_to_video
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = T5EncoderModel.from_pretrained(
+ "rhymes-ai/Allegro",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = AllegroTransformer3DModel.from_pretrained(
+ "rhymes-ai/Allegro",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = AllegroPipeline.from_pretrained(
+ "rhymes-ai/Allegro",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = (
+ "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, "
+ "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this "
+ "location might be a popular spot for docking fishing boats."
+)
+video = pipeline(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0]
+export_to_video(video, "harbor.mp4", fps=15)
+```
+
+## AllegroPipeline
+
+[[autodoc]] AllegroPipeline
+ - all
+ - __call__
+
+## AllegroPipelineOutput
+
+[[autodoc]] pipelines.allegro.pipeline_output.AllegroPipelineOutput
diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md
index 735901280362..ed5ced7dbbc7 100644
--- a/docs/source/en/api/pipelines/animatediff.md
+++ b/docs/source/en/api/pipelines/animatediff.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Text-to-Video Generation with AnimateDiff
+
+
+
+
## Overview
[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725) by Yuwei Guo, Ceyuan Yang, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai.
@@ -803,7 +807,7 @@ FreeInit is not really free - the improved quality comes at the cost of extra co
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/attend_and_excite.md b/docs/source/en/api/pipelines/attend_and_excite.md
index fd8dd95fa1c3..953ab1bb7288 100644
--- a/docs/source/en/api/pipelines/attend_and_excite.md
+++ b/docs/source/en/api/pipelines/attend_and_excite.md
@@ -22,7 +22,7 @@ You can find additional information about Attend-and-Excite on the [project page
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/audioldm.md b/docs/source/en/api/pipelines/audioldm.md
index 95d41b9569f5..02fe2c779eee 100644
--- a/docs/source/en/api/pipelines/audioldm.md
+++ b/docs/source/en/api/pipelines/audioldm.md
@@ -37,7 +37,7 @@ During inference:
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/audioldm2.md b/docs/source/en/api/pipelines/audioldm2.md
index 9f2b7529d4bc..debd2c3433e4 100644
--- a/docs/source/en/api/pipelines/audioldm2.md
+++ b/docs/source/en/api/pipelines/audioldm2.md
@@ -60,7 +60,7 @@ The following example demonstrates how to construct good music and speech genera
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/aura_flow.md b/docs/source/en/api/pipelines/aura_flow.md
index aa5a04800e6f..5d58690505b3 100644
--- a/docs/source/en/api/pipelines/aura_flow.md
+++ b/docs/source/en/api/pipelines/aura_flow.md
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# AuraFlow
-AuraFlow is inspired by [Stable Diffusion 3](../pipelines/stable_diffusion/stable_diffusion_3.md) and is by far the largest text-to-image generation model that comes with an Apache 2.0 license. This model achieves state-of-the-art results on the [GenEval](https://github.com/djghosh13/geneval) benchmark.
+AuraFlow is inspired by [Stable Diffusion 3](../pipelines/stable_diffusion/stable_diffusion_3) and is by far the largest text-to-image generation model that comes with an Apache 2.0 license. This model achieves state-of-the-art results on the [GenEval](https://github.com/djghosh13/geneval) benchmark.
It was developed by the Fal team and more details about it can be found in [this blog post](https://blog.fal.ai/auraflow/).
@@ -22,6 +22,73 @@ AuraFlow can be quite expensive to run on consumer hardware devices. However, yo
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`AuraFlowPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, AuraFlowTransformer2DModel, AuraFlowPipeline
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = T5EncoderModel.from_pretrained(
+ "fal/AuraFlow",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = AuraFlowTransformer2DModel.from_pretrained(
+ "fal/AuraFlow",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = AuraFlowPipeline.from_pretrained(
+ "fal/AuraFlow",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon"
+image = pipeline(prompt).images[0]
+image.save("auraflow.png")
+```
+
+Loading [GGUF checkpoints](https://huggingface.co/docs/diffusers/quantization/gguf) are also supported:
+
+```py
+import torch
+from diffusers import (
+ AuraFlowPipeline,
+ GGUFQuantizationConfig,
+ AuraFlowTransformer2DModel,
+)
+
+transformer = AuraFlowTransformer2DModel.from_single_file(
+ "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+ torch_dtype=torch.bfloat16,
+)
+
+pipeline = AuraFlowPipeline.from_pretrained(
+ "fal/AuraFlow-v0.3",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16,
+)
+
+prompt = "a cute pony in a field of flowers"
+image = pipeline(prompt).images[0]
+image.save("auraflow.png")
+```
+
## AuraFlowPipeline
[[autodoc]] AuraFlowPipeline
diff --git a/docs/source/en/api/pipelines/blip_diffusion.md b/docs/source/en/api/pipelines/blip_diffusion.md
index b4504f6d6b19..15d17da8f07c 100644
--- a/docs/source/en/api/pipelines/blip_diffusion.md
+++ b/docs/source/en/api/pipelines/blip_diffusion.md
@@ -25,7 +25,7 @@ The original codebase can be found at [salesforce/LAVIS](https://github.com/sale
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md
index f0f4fd37e6d5..0de40f934548 100644
--- a/docs/source/en/api/pipelines/cogvideox.md
+++ b/docs/source/en/api/pipelines/cogvideox.md
@@ -15,6 +15,10 @@
# CogVideoX
+
+
+
+
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
The abstract from the paper is:
@@ -23,22 +27,38 @@ The abstract from the paper is:
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
-There are two models available that can be used with the text-to-video and video-to-video CogVideoX pipelines:
-- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `fp16`.
-- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `bf16`.
+There are three official CogVideoX checkpoints for text-to-video and video-to-video.
+
+| checkpoints | recommended inference dtype |
+|:---:|:---:|
+| [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 |
+| [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 |
+| [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 |
+
+There are two official CogVideoX checkpoints available for image-to-video.
+
+| checkpoints | recommended inference dtype |
+|:---:|:---:|
+| [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 |
+| [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 |
-There is one model available that can be used with the image-to-video CogVideoX pipeline:
-- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.
+For the CogVideoX 1.5 series:
+- Text-to-video (T2V) works best at a resolution of 1360x768 because it was trained with that specific resolution.
+- Image-to-video (I2V) works for multiple resolutions. The width can vary from 768 to 1360, but the height must be 768. The height/width must be divisible by 16.
+- Both T2V and I2V models support generation with 81 and 161 frames and work best at this value. Exporting videos at 16 FPS is recommended.
-There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team):
-- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`.
-- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`.
+There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team).
+
+| checkpoints | recommended inference dtype |
+|:---:|:---:|
+| [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 |
+| [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 |
## Inference
@@ -96,13 +116,46 @@ CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds o
- With enabling cpu offloading and tiling, memory usage is `11 GB`
- `pipe.vae.enable_slicing()`
-### Quantized inference
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
-[torchao](https://github.com/pytorch/ao) and [optimum-quanto](https://github.com/huggingface/optimum-quanto/) can be used to quantize the text encoder, transformer and VAE modules to lower the memory requirements. This makes it possible to run the model on a free-tier T4 Colab or lower VRAM GPUs!
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`CogVideoXPipeline`] for inference with bitsandbytes.
-It is also worth noting that torchao quantization is fully compatible with [torch.compile](/optimization/torch2.0#torchcompile), which allows for much faster inference speed. Additionally, models can be serialized and stored in a quantized datatype to save disk space with torchao. Find examples and benchmarks in the gists below.
-- [torchao](https://gist.github.com/a-r-r-o-w/4d9732d17412888c885480c6521a9897)
-- [quanto](https://gist.github.com/a-r-r-o-w/31be62828b00a9292821b85c1017effa)
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, CogVideoXTransformer3DModel, CogVideoXPipeline
+from diffusers.utils import export_to_video
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = T5EncoderModel.from_pretrained(
+ "THUDM/CogVideoX-2b",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = CogVideoXTransformer3DModel.from_pretrained(
+ "THUDM/CogVideoX-2b",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = CogVideoXPipeline.from_pretrained(
+ "THUDM/CogVideoX-2b",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
+video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+export_to_video(video, "ship.mp4", fps=8)
+```
## CogVideoXPipeline
diff --git a/docs/source/en/api/pipelines/cogview3.md b/docs/source/en/api/pipelines/cogview3.md
index 85a9cf91736f..277edca4cf33 100644
--- a/docs/source/en/api/pipelines/cogview3.md
+++ b/docs/source/en/api/pipelines/cogview3.md
@@ -23,7 +23,7 @@ The abstract from the paper is:
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/cogview4.md b/docs/source/en/api/pipelines/cogview4.md
new file mode 100644
index 000000000000..cc17c3c905fb
--- /dev/null
+++ b/docs/source/en/api/pipelines/cogview4.md
@@ -0,0 +1,34 @@
+
+
+# CogView4
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
+
+## CogView4Pipeline
+
+[[autodoc]] CogView4Pipeline
+ - all
+ - __call__
+
+## CogView4PipelineOutput
+
+[[autodoc]] pipelines.cogview4.pipeline_output.CogView4PipelineOutput
diff --git a/docs/source/en/api/pipelines/consisid.md b/docs/source/en/api/pipelines/consisid.md
new file mode 100644
index 000000000000..6a23f223a6ca
--- /dev/null
+++ b/docs/source/en/api/pipelines/consisid.md
@@ -0,0 +1,64 @@
+
+
+# ConsisID
+
+
+
+
+
+[Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://arxiv.org/abs/2411.17440) from Peking University & University of Rochester & etc, by Shenghai Yuan, Jinfa Huang, Xianyi He, Yunyang Ge, Yujun Shi, Liuhan Chen, Jiebo Luo, Li Yuan.
+
+The abstract from the paper is:
+
+*Identity-preserving text-to-video (IPT2V) generation aims to create high-fidelity videos with consistent human identity. It is an important task in video generation but remains an open problem for generative models. This paper pushes the technical frontier of IPT2V in two directions that have not been resolved in the literature: (1) A tuning-free pipeline without tedious case-by-case finetuning, and (2) A frequency-aware heuristic identity-preserving Diffusion Transformer (DiT)-based control scheme. To achieve these goals, we propose **ConsisID**, a tuning-free DiT-based controllable IPT2V model to keep human-**id**entity **consis**tent in the generated video. Inspired by prior findings in frequency analysis of vision/diffusion transformers, it employs identity-control signals in the frequency domain, where facial features can be decomposed into low-frequency global features (e.g., profile, proportions) and high-frequency intrinsic features (e.g., identity markers that remain unaffected by pose changes). First, from a low-frequency perspective, we introduce a global facial extractor, which encodes the reference image and facial key points into a latent space, generating features enriched with low-frequency information. These features are then integrated into the shallow layers of the network to alleviate training challenges associated with DiT. Second, from a high-frequency perspective, we design a local facial extractor to capture high-frequency details and inject them into the transformer blocks, enhancing the model's ability to preserve fine-grained features. To leverage the frequency information for identity preservation, we propose a hierarchical training strategy, transforming a vanilla pre-trained video generation model into an IPT2V model. Extensive experiments demonstrate that our frequency-aware heuristic scheme provides an optimal control solution for DiT-based models. Thanks to this scheme, our **ConsisID** achieves excellent results in generating high-quality, identity-preserving videos, making strides towards more effective IPT2V. The model weight of ConsID is publicly available at https://github.com/PKU-YuanGroup/ConsisID.*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+This pipeline was contributed by [SHYuanBest](https://github.com/SHYuanBest). The original codebase can be found [here](https://github.com/PKU-YuanGroup/ConsisID). The original weights can be found under [hf.co/BestWishYsh](https://huggingface.co/BestWishYsh).
+
+There are two official ConsisID checkpoints for identity-preserving text-to-video.
+
+| checkpoints | recommended inference dtype |
+|:---:|:---:|
+| [`BestWishYsh/ConsisID-preview`](https://huggingface.co/BestWishYsh/ConsisID-preview) | torch.bfloat16 |
+| [`BestWishYsh/ConsisID-1.5`](https://huggingface.co/BestWishYsh/ConsisID-preview) | torch.bfloat16 |
+
+### Memory optimization
+
+ConsisID requires about 44 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/SHYuanBest/bc4207c36f454f9e969adbb50eaf8258) script.
+
+| Feature (overlay the previous) | Max Memory Allocated | Max Memory Reserved |
+| :----------------------------- | :------------------- | :------------------ |
+| - | 37 GB | 44 GB |
+| enable_model_cpu_offload | 22 GB | 25 GB |
+| enable_sequential_cpu_offload | 16 GB | 22 GB |
+| vae.enable_slicing | 16 GB | 22 GB |
+| vae.enable_tiling | 5 GB | 7 GB |
+
+## ConsisIDPipeline
+
+[[autodoc]] ConsisIDPipeline
+
+ - all
+ - __call__
+
+## ConsisIDPipelineOutput
+
+[[autodoc]] pipelines.consisid.pipeline_output.ConsisIDPipelineOutput
diff --git a/docs/source/en/api/pipelines/control_flux_inpaint.md b/docs/source/en/api/pipelines/control_flux_inpaint.md
new file mode 100644
index 000000000000..3e8edb498766
--- /dev/null
+++ b/docs/source/en/api/pipelines/control_flux_inpaint.md
@@ -0,0 +1,93 @@
+
+
+# FluxControlInpaint
+
+
+
+
+
+FluxControlInpaintPipeline is an implementation of Inpainting for Flux.1 Depth/Canny models. It is a pipeline that allows you to inpaint images using the Flux.1 Depth/Canny models. The pipeline takes an image and a mask as input and returns the inpainted image.
+
+FLUX.1 Depth and Canny [dev] is a 12 billion parameter rectified flow transformer capable of generating an image based on a text description while following the structure of a given input image. **This is not a ControlNet model**.
+
+| Control type | Developer | Link |
+| -------- | ---------- | ---- |
+| Depth | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
+| Canny | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
+
+
+
+
+Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
+
+
+
+```python
+import torch
+from diffusers import FluxControlInpaintPipeline
+from diffusers.models.transformers import FluxTransformer2DModel
+from transformers import T5EncoderModel
+from diffusers.utils import load_image, make_image_grid
+from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux
+from PIL import Image
+import numpy as np
+
+pipe = FluxControlInpaintPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Depth-dev",
+ torch_dtype=torch.bfloat16,
+)
+# use following lines if you have GPU constraints
+# ---------------------------------------------------------------
+transformer = FluxTransformer2DModel.from_pretrained(
+ "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
+)
+text_encoder_2 = T5EncoderModel.from_pretrained(
+ "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+)
+pipe.transformer = transformer
+pipe.text_encoder_2 = text_encoder_2
+pipe.enable_model_cpu_offload()
+# ---------------------------------------------------------------
+pipe.to("cuda")
+
+prompt = "a blue robot singing opera with human-like expressions"
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+head_mask = np.zeros_like(image)
+head_mask[65:580,300:642] = 255
+mask_image = Image.fromarray(head_mask)
+
+processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
+control_image = processor(image)[0].convert("RGB")
+
+output = pipe(
+ prompt=prompt,
+ image=image,
+ control_image=control_image,
+ mask_image=mask_image,
+ num_inference_steps=30,
+ strength=0.9,
+ guidance_scale=10.0,
+ generator=torch.Generator().manual_seed(42),
+).images[0]
+make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save("output.png")
+```
+
+## FluxControlInpaintPipeline
+[[autodoc]] FluxControlInpaintPipeline
+ - all
+ - __call__
+
+
+## FluxPipelineOutput
+[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/controlnet.md b/docs/source/en/api/pipelines/controlnet.md
index 6b00902cf296..11f2c4f11f73 100644
--- a/docs/source/en/api/pipelines/controlnet.md
+++ b/docs/source/en/api/pipelines/controlnet.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# ControlNet
+
+
+
+
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
@@ -26,7 +30,7 @@ The original codebase can be found at [lllyasviel/ControlNet](https://github.com
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/controlnet_flux.md b/docs/source/en/api/pipelines/controlnet_flux.md
index 82454ae5e930..1bb15d7aabb2 100644
--- a/docs/source/en/api/pipelines/controlnet_flux.md
+++ b/docs/source/en/api/pipelines/controlnet_flux.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# ControlNet with Flux.1
+
+
+
+
FluxControlNetPipeline is an implementation of ControlNet for Flux.1.
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
@@ -42,7 +46,7 @@ XLabs ControlNets are also supported, which was contributed by the [XLabs team](
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/controlnet_hunyuandit.md b/docs/source/en/api/pipelines/controlnet_hunyuandit.md
index e702eb30b8b0..6776b88ab35f 100644
--- a/docs/source/en/api/pipelines/controlnet_hunyuandit.md
+++ b/docs/source/en/api/pipelines/controlnet_hunyuandit.md
@@ -26,7 +26,7 @@ This code is implemented by Tencent Hunyuan Team. You can find pre-trained check
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/controlnet_sd3.md b/docs/source/en/api/pipelines/controlnet_sd3.md
index bb91a43cbaef..cee52ef5d76e 100644
--- a/docs/source/en/api/pipelines/controlnet_sd3.md
+++ b/docs/source/en/api/pipelines/controlnet_sd3.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# ControlNet with Stable Diffusion 3
+
+
+
+
StableDiffusion3ControlNetPipeline is an implementation of ControlNet for Stable Diffusion 3.
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
@@ -28,6 +32,7 @@ This controlnet code is mainly implemented by [The InstantX Team](https://huggin
| ControlNet type | Developer | Link |
| -------- | ---------- | ---- |
| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Canny) |
+| Depth | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Depth) |
| Pose | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Pose) |
| Tile | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Tile) |
| Inpainting | [The AlimamaCreative Team](https://huggingface.co/alimama-creative) | [link](https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting) |
@@ -35,7 +40,7 @@ This controlnet code is mainly implemented by [The InstantX Team](https://huggin
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/controlnet_sdxl.md b/docs/source/en/api/pipelines/controlnet_sdxl.md
index 2de7cbff6ebc..f299702297b4 100644
--- a/docs/source/en/api/pipelines/controlnet_sdxl.md
+++ b/docs/source/en/api/pipelines/controlnet_sdxl.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# ControlNet with Stable Diffusion XL
+
+
+
+
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
@@ -32,7 +36,7 @@ If you don't see a checkpoint you're interested in, you can train your own SDXL
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/controlnet_union.md b/docs/source/en/api/pipelines/controlnet_union.md
new file mode 100644
index 000000000000..58ae19e778dd
--- /dev/null
+++ b/docs/source/en/api/pipelines/controlnet_union.md
@@ -0,0 +1,39 @@
+
+
+# ControlNetUnion
+
+
+
+
+
+ControlNetUnionModel is an implementation of ControlNet for Stable Diffusion XL.
+
+The ControlNet model was introduced in [ControlNetPlus](https://github.com/xinsir6/ControlNetPlus) by xinsir6. It supports multiple conditioning inputs without increasing computation.
+
+*We design a new architecture that can support 10+ control types in condition text-to-image generation and can generate high resolution images visually comparable with midjourney. The network is based on the original ControlNet architecture, we propose two new modules to: 1 Extend the original ControlNet to support different image conditions using the same network parameter. 2 Support multiple conditions input without increasing computation offload, which is especially important for designers who want to edit image in detail, different conditions use the same condition encoder, without adding extra computations or parameters.*
+
+
+## StableDiffusionXLControlNetUnionPipeline
+[[autodoc]] StableDiffusionXLControlNetUnionPipeline
+ - all
+ - __call__
+
+## StableDiffusionXLControlNetUnionImg2ImgPipeline
+[[autodoc]] StableDiffusionXLControlNetUnionImg2ImgPipeline
+ - all
+ - __call__
+
+## StableDiffusionXLControlNetUnionInpaintPipeline
+[[autodoc]] StableDiffusionXLControlNetUnionInpaintPipeline
+ - all
+ - __call__
diff --git a/docs/source/en/api/pipelines/controlnetxs.md b/docs/source/en/api/pipelines/controlnetxs.md
index 2d4ae7b8ce46..2eebcc6b74d3 100644
--- a/docs/source/en/api/pipelines/controlnetxs.md
+++ b/docs/source/en/api/pipelines/controlnetxs.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# ControlNet-XS
+
+
+
+
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
Like the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
@@ -26,7 +30,7 @@ This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/controlnetxs_sdxl.md b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
index 31075c0ef96a..0862a5d79878 100644
--- a/docs/source/en/api/pipelines/controlnetxs_sdxl.md
+++ b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
@@ -32,7 +32,7 @@ This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/dance_diffusion.md b/docs/source/en/api/pipelines/dance_diffusion.md
index efba3c3763a4..9b6e7b66e198 100644
--- a/docs/source/en/api/pipelines/dance_diffusion.md
+++ b/docs/source/en/api/pipelines/dance_diffusion.md
@@ -19,7 +19,7 @@ Dance Diffusion is the first in a suite of generative audio tools for producers
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/ddpm.md b/docs/source/en/api/pipelines/ddpm.md
index 81ddb5e0c051..0935f0bec79c 100644
--- a/docs/source/en/api/pipelines/ddpm.md
+++ b/docs/source/en/api/pipelines/ddpm.md
@@ -22,7 +22,7 @@ The original codebase can be found at [hohonathanho/diffusion](https://github.co
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/deepfloyd_if.md b/docs/source/en/api/pipelines/deepfloyd_if.md
index 00441980d802..162476619867 100644
--- a/docs/source/en/api/pipelines/deepfloyd_if.md
+++ b/docs/source/en/api/pipelines/deepfloyd_if.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# DeepFloyd IF
+
+
+
+
## Overview
DeepFloyd IF is a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding.
diff --git a/docs/source/en/api/pipelines/dit.md b/docs/source/en/api/pipelines/dit.md
index 1d04458d9cb9..2ee45b631c77 100644
--- a/docs/source/en/api/pipelines/dit.md
+++ b/docs/source/en/api/pipelines/dit.md
@@ -22,7 +22,7 @@ The original codebase can be found at [facebookresearch/dit](https://github.com/
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/easyanimate.md b/docs/source/en/api/pipelines/easyanimate.md
new file mode 100644
index 000000000000..15d44a12b1b6
--- /dev/null
+++ b/docs/source/en/api/pipelines/easyanimate.md
@@ -0,0 +1,88 @@
+
+
+# EasyAnimate
+[EasyAnimate](https://github.com/aigc-apps/EasyAnimate) by Alibaba PAI.
+
+The description from it's GitHub page:
+*EasyAnimate is a pipeline based on the transformer architecture, designed for generating AI images and videos, and for training baseline models and Lora models for Diffusion Transformer. We support direct prediction from pre-trained EasyAnimate models, allowing for the generation of videos with various resolutions, approximately 6 seconds in length, at 8fps (EasyAnimateV5.1, 1 to 49 frames). Additionally, users can train their own baseline and Lora models for specific style transformations.*
+
+This pipeline was contributed by [bubbliiiing](https://github.com/bubbliiiing). The original codebase can be found [here](https://huggingface.co/alibaba-pai). The original weights can be found under [hf.co/alibaba-pai](https://huggingface.co/alibaba-pai).
+
+There are two official EasyAnimate checkpoints for text-to-video and video-to-video.
+
+| checkpoints | recommended inference dtype |
+|:---:|:---:|
+| [`alibaba-pai/EasyAnimateV5.1-12b-zh`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh) | torch.float16 |
+| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |
+
+There is one official EasyAnimate checkpoints available for image-to-video and video-to-video.
+
+| checkpoints | recommended inference dtype |
+|:---:|:---:|
+| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |
+
+There are two official EasyAnimate checkpoints available for control-to-video.
+
+| checkpoints | recommended inference dtype |
+|:---:|:---:|
+| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control) | torch.float16 |
+| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera) | torch.float16 |
+
+For the EasyAnimateV5.1 series:
+- Text-to-video (T2V) and Image-to-video (I2V) works for multiple resolutions. The width and height can vary from 256 to 1024.
+- Both T2V and I2V models support generation with 1~49 frames and work best at this value. Exporting videos at 8 FPS is recommended.
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`EasyAnimatePipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, EasyAnimateTransformer3DModel, EasyAnimatePipeline
+from diffusers.utils import export_to_video
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = EasyAnimateTransformer3DModel.from_pretrained(
+ "alibaba-pai/EasyAnimateV5.1-12b-zh",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = EasyAnimatePipeline.from_pretrained(
+ "alibaba-pai/EasyAnimateV5.1-12b-zh",
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = "A cat walks on the grass, realistic style."
+negative_prompt = "bad detailed"
+video = pipeline(prompt=prompt, negative_prompt=negative_prompt, num_frames=49, num_inference_steps=30).frames[0]
+export_to_video(video, "cat.mp4", fps=8)
+```
+
+## EasyAnimatePipeline
+
+[[autodoc]] EasyAnimatePipeline
+ - all
+ - __call__
+
+## EasyAnimatePipelineOutput
+
+[[autodoc]] pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput
diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md
index 255c69c854bc..44f6096edfb3 100644
--- a/docs/source/en/api/pipelines/flux.md
+++ b/docs/source/en/api/pipelines/flux.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Flux
+
+
+
+
Flux is a series of text-to-image generation models based on diffusion transformers. To know more about Flux, check out the original [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) by the creators of Flux, Black Forest Labs.
Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux).
@@ -22,12 +26,20 @@ Flux can be quite expensive to run on consumer hardware devices. However, you ca
-Flux comes in two variants:
+Flux comes in the following variants:
-* Timestep-distilled (`black-forest-labs/FLUX.1-schnell`)
-* Guidance-distilled (`black-forest-labs/FLUX.1-dev`)
+| model type | model id |
+|:----------:|:--------:|
+| Timestep-distilled | [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) |
+| Guidance-distilled | [`black-forest-labs/FLUX.1-dev`](https://huggingface.co/black-forest-labs/FLUX.1-dev) |
+| Fill Inpainting/Outpainting (Guidance-distilled) | [`black-forest-labs/FLUX.1-Fill-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) |
+| Canny Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Canny-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
+| Depth Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Depth-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
+| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
+| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
+| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
-Both checkpoints have slightly difference usage which we detail below.
+All checkpoints have different usage which we detail below.
### Timestep-distilled
@@ -77,7 +89,345 @@ out = pipe(
out.save("image.png")
```
-## Running FP16 inference
+### Fill Inpainting/Outpainting
+
+* Flux Fill pipeline does not require `strength` as an input like regular inpainting pipelines.
+* It supports both inpainting and outpainting.
+
+```python
+import torch
+from diffusers import FluxFillPipeline
+from diffusers.utils import load_image
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png")
+mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png")
+
+repo_id = "black-forest-labs/FLUX.1-Fill-dev"
+pipe = FluxFillPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cuda")
+
+image = pipe(
+ prompt="a white paper cup",
+ image=image,
+ mask_image=mask,
+ height=1632,
+ width=1232,
+ max_sequence_length=512,
+ generator=torch.Generator("cpu").manual_seed(0)
+).images[0]
+image.save(f"output.png")
+```
+
+### Canny Control
+
+**Note:** `black-forest-labs/Flux.1-Canny-dev` is _not_ a [`ControlNetModel`] model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Canny Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible.
+
+```python
+# !pip install -U controlnet-aux
+import torch
+from controlnet_aux import CannyDetector
+from diffusers import FluxControlPipeline
+from diffusers.utils import load_image
+
+pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16).to("cuda")
+
+prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+processor = CannyDetector()
+control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)
+
+image = pipe(
+ prompt=prompt,
+ control_image=control_image,
+ height=1024,
+ width=1024,
+ num_inference_steps=50,
+ guidance_scale=30.0,
+).images[0]
+image.save("output.png")
+```
+
+Canny Control is also possible with a LoRA variant of this condition. The usage is as follows:
+
+```python
+# !pip install -U controlnet-aux
+import torch
+from controlnet_aux import CannyDetector
+from diffusers import FluxControlPipeline
+from diffusers.utils import load_image
+
+pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
+pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+
+prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+processor = CannyDetector()
+control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)
+
+image = pipe(
+ prompt=prompt,
+ control_image=control_image,
+ height=1024,
+ width=1024,
+ num_inference_steps=50,
+ guidance_scale=30.0,
+).images[0]
+image.save("output.png")
+```
+
+### Depth Control
+
+**Note:** `black-forest-labs/Flux.1-Depth-dev` is _not_ a ControlNet model. [`ControlNetModel`] models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible.
+
+```python
+# !pip install git+https://github.com/huggingface/image_gen_aux
+import torch
+from diffusers import FluxControlPipeline, FluxTransformer2DModel
+from diffusers.utils import load_image
+from image_gen_aux import DepthPreprocessor
+
+pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16).to("cuda")
+
+prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
+control_image = processor(control_image)[0].convert("RGB")
+
+image = pipe(
+ prompt=prompt,
+ control_image=control_image,
+ height=1024,
+ width=1024,
+ num_inference_steps=30,
+ guidance_scale=10.0,
+ generator=torch.Generator().manual_seed(42),
+).images[0]
+image.save("output.png")
+```
+
+Depth Control is also possible with a LoRA variant of this condition. The usage is as follows:
+
+```python
+# !pip install git+https://github.com/huggingface/image_gen_aux
+import torch
+from diffusers import FluxControlPipeline, FluxTransformer2DModel
+from diffusers.utils import load_image
+from image_gen_aux import DepthPreprocessor
+
+pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
+pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")
+
+prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
+control_image = processor(control_image)[0].convert("RGB")
+
+image = pipe(
+ prompt=prompt,
+ control_image=control_image,
+ height=1024,
+ width=1024,
+ num_inference_steps=30,
+ guidance_scale=10.0,
+ generator=torch.Generator().manual_seed(42),
+).images[0]
+image.save("output.png")
+```
+
+### Redux
+
+* Flux Redux pipeline is an adapter for FLUX.1 base models. It can be used with both flux-dev and flux-schnell, for image-to-image generation.
+* You can first use the `FluxPriorReduxPipeline` to get the `prompt_embeds` and `pooled_prompt_embeds`, and then feed them into the `FluxPipeline` for image-to-image generation.
+* When use `FluxPriorReduxPipeline` with a base pipeline, you can set `text_encoder=None` and `text_encoder_2=None` in the base pipeline, in order to save VRAM.
+
+```python
+import torch
+from diffusers import FluxPriorReduxPipeline, FluxPipeline
+from diffusers.utils import load_image
+device = "cuda"
+dtype = torch.bfloat16
+
+
+repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
+repo_base = "black-forest-labs/FLUX.1-dev"
+pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
+pipe = FluxPipeline.from_pretrained(
+ repo_base,
+ text_encoder=None,
+ text_encoder_2=None,
+ torch_dtype=torch.bfloat16
+).to(device)
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png")
+pipe_prior_output = pipe_prior_redux(image)
+images = pipe(
+ guidance_scale=2.5,
+ num_inference_steps=50,
+ generator=torch.Generator("cpu").manual_seed(0),
+ **pipe_prior_output,
+).images
+images[0].save("flux-redux.png")
+```
+
+## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
+
+We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
+
+```py
+from diffusers import FluxControlPipeline
+from image_gen_aux import DepthPreprocessor
+from diffusers.utils import load_image
+from huggingface_hub import hf_hub_download
+import torch
+
+control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
+control_pipe.load_lora_weights(
+ hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
+)
+control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
+control_pipe.enable_model_cpu_offload()
+
+prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
+control_image = processor(control_image)[0].convert("RGB")
+
+image = control_pipe(
+ prompt=prompt,
+ control_image=control_image,
+ height=1024,
+ width=1024,
+ num_inference_steps=8,
+ guidance_scale=10.0,
+ generator=torch.Generator().manual_seed(42),
+).images[0]
+image.save("output.png")
+```
+
+## Note about `unload_lora_weights()` when using Flux LoRAs
+
+When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resultant pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).
+
+## IP-Adapter
+
+
+
+Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
+
+
+
+An IP-Adapter lets you prompt Flux with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images.
+
+```python
+import torch
+from diffusers import FluxPipeline
+from diffusers.utils import load_image
+
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg").resize((1024, 1024))
+
+pipe.load_ip_adapter(
+ "XLabs-AI/flux-ip-adapter",
+ weight_name="ip_adapter.safetensors",
+ image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14"
+)
+pipe.set_ip_adapter_scale(1.0)
+
+image = pipe(
+ width=1024,
+ height=1024,
+ prompt="wearing sunglasses",
+ negative_prompt="",
+ true_cfg=4.0,
+ generator=torch.Generator().manual_seed(4444),
+ ip_adapter_image=image,
+).images[0]
+
+image.save('flux_ip_adapter_output.jpg')
+```
+
+
+
+
IP-Adapter examples with prompt "wearing sunglasses"
+
+
+## Optimize
+
+Flux is a very large model and requires ~50GB of RAM/VRAM to load all the modeling components. Enable some of the optimizations below to lower the memory requirements.
+
+### Group offloading
+
+[Group offloading](../../optimization/memory#group-offloading) lowers VRAM usage by offloading groups of internal layers rather than the whole model or weights. You need to use [`~hooks.apply_group_offloading`] on all the model components of a pipeline. The `offload_type` parameter allows you to toggle between block and leaf-level offloading. Setting it to `leaf_level` offloads the lowest leaf-level parameters to the CPU instead of offloading at the module-level.
+
+On CUDA devices that support asynchronous data streaming, set `use_stream=True` to overlap data transfer and computation to accelerate inference.
+
+> [!TIP]
+> It is possible to mix block and leaf-level offloading for different components in a pipeline.
+
+```py
+import torch
+from diffusers import FluxPipeline
+from diffusers.hooks import apply_group_offloading
+
+model_id = "black-forest-labs/FLUX.1-dev"
+dtype = torch.bfloat16
+pipe = FluxPipeline.from_pretrained(
+ model_id,
+ torch_dtype=dtype,
+)
+
+apply_group_offloading(
+ pipe.transformer,
+ offload_type="leaf_level",
+ offload_device=torch.device("cpu"),
+ onload_device=torch.device("cuda"),
+ use_stream=True,
+)
+apply_group_offloading(
+ pipe.text_encoder,
+ offload_device=torch.device("cpu"),
+ onload_device=torch.device("cuda"),
+ offload_type="leaf_level",
+ use_stream=True,
+)
+apply_group_offloading(
+ pipe.text_encoder_2,
+ offload_device=torch.device("cpu"),
+ onload_device=torch.device("cuda"),
+ offload_type="leaf_level",
+ use_stream=True,
+)
+apply_group_offloading(
+ pipe.vae,
+ offload_device=torch.device("cpu"),
+ onload_device=torch.device("cuda"),
+ offload_type="leaf_level",
+ use_stream=True,
+)
+
+prompt="A cat wearing sunglasses and working as a lifeguard at pool."
+
+generator = torch.Generator().manual_seed(181201)
+image = pipe(
+ prompt,
+ width=576,
+ height=1024,
+ num_inference_steps=30,
+ generator=generator
+).images[0]
+image
+```
+
+### Running FP16 inference
+
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
FP16 inference code:
@@ -105,6 +455,46 @@ out = pipe(
out.save("image.png")
```
+### Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`FluxPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = T5EncoderModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="text_encoder_2",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ text_encoder_2=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon"
+image = pipeline(prompt, guidance_scale=3.5, height=768, width=1360, num_inference_steps=50).images[0]
+image.save("flux.png")
+```
+
## Single File Loading for the `FluxTransformer2DModel`
The `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
@@ -188,3 +578,27 @@ image.save("flux-fp8-dev.png")
[[autodoc]] FluxControlNetImg2ImgPipeline
- all
- __call__
+
+## FluxControlPipeline
+
+[[autodoc]] FluxControlPipeline
+ - all
+ - __call__
+
+## FluxControlImg2ImgPipeline
+
+[[autodoc]] FluxControlImg2ImgPipeline
+ - all
+ - __call__
+
+## FluxPriorReduxPipeline
+
+[[autodoc]] FluxPriorReduxPipeline
+ - all
+ - __call__
+
+## FluxFillPipeline
+
+[[autodoc]] FluxFillPipeline
+ - all
+ - __call__
diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md
new file mode 100644
index 000000000000..5d068c8b6ef8
--- /dev/null
+++ b/docs/source/en/api/pipelines/hunyuan_video.md
@@ -0,0 +1,95 @@
+
+
+# HunyuanVideo
+
+
+
+
+
+[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.
+
+*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/tencent/HunyuanVideo).*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+Recommendations for inference:
+- Both text encoders should be in `torch.float16`.
+- Transformer should be in `torch.bfloat16`.
+- VAE should be in `torch.float16`.
+- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
+- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
+- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
+
+## Available models
+
+The following models are available for the [`HunyuanVideoPipeline`](text-to-video) pipeline:
+
+| Model name | Description |
+|:---|:---|
+| [`hunyuanvideo-community/HunyuanVideo`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) | Official HunyuanVideo (guidance-distilled). Performs best at multiple resolutions and frames. Performs best with `guidance_scale=6.0`, `true_cfg_scale=1.0` and without a negative prompt. |
+| [`https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-T2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-T2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
+
+The following models are available for the image-to-video pipeline:
+
+| Model name | Description |
+|:---|:---|
+| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
+| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
+| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`HunyuanVideoPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline
+from diffusers.utils import export_to_video
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
+ "hunyuanvideo-community/HunyuanVideo",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+)
+
+pipeline = HunyuanVideoPipeline.from_pretrained(
+ "hunyuanvideo-community/HunyuanVideo",
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = "A cat walks on the grass, realistic style."
+video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
+export_to_video(video, "cat.mp4", fps=15)
+```
+
+## HunyuanVideoPipeline
+
+[[autodoc]] HunyuanVideoPipeline
+ - all
+ - __call__
+
+## HunyuanVideoPipelineOutput
+
+[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput
diff --git a/docs/source/en/api/pipelines/hunyuandit.md b/docs/source/en/api/pipelines/hunyuandit.md
index 250533837ed0..d593259a09ed 100644
--- a/docs/source/en/api/pipelines/hunyuandit.md
+++ b/docs/source/en/api/pipelines/hunyuandit.md
@@ -30,7 +30,7 @@ HunyuanDiT has the following components:
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/i2vgenxl.md b/docs/source/en/api/pipelines/i2vgenxl.md
index cbb6be1176fd..3994f91d2cd0 100644
--- a/docs/source/en/api/pipelines/i2vgenxl.md
+++ b/docs/source/en/api/pipelines/i2vgenxl.md
@@ -22,7 +22,7 @@ The original codebase can be found [here](https://github.com/ali-vilab/i2vgen-xl
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage).
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage).
diff --git a/docs/source/en/api/pipelines/kandinsky.md b/docs/source/en/api/pipelines/kandinsky.md
index 9ea3cd4a1718..72cbf3fb474d 100644
--- a/docs/source/en/api/pipelines/kandinsky.md
+++ b/docs/source/en/api/pipelines/kandinsky.md
@@ -25,7 +25,7 @@ Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community)
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/kandinsky3.md b/docs/source/en/api/pipelines/kandinsky3.md
index 96123846af32..f4bea2b117d3 100644
--- a/docs/source/en/api/pipelines/kandinsky3.md
+++ b/docs/source/en/api/pipelines/kandinsky3.md
@@ -9,6 +9,10 @@ specific language governing permissions and limitations under the License.
# Kandinsky 3
+
+
+
+
Kandinsky 3 is created by [Vladimir Arkhipkin](https://github.com/oriBetelgeuse),[Anastasia Maltseva](https://github.com/NastyaMittseva),[Igor Pavlov](https://github.com/boomb0om),[Andrei Filatov](https://github.com/anvilarth),[Arseniy Shakhmatov](https://github.com/cene555),[Andrey Kuznetsov](https://github.com/kuznetsoffandrey),[Denis Dimitrov](https://github.com/denndimitrov), [Zein Shaheen](https://github.com/zeinsh)
The description from it's GitHub page:
@@ -32,7 +36,7 @@ Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community)
-Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/kandinsky_v22.md b/docs/source/en/api/pipelines/kandinsky_v22.md
index 13a6ca81d4a5..f097a085ef7f 100644
--- a/docs/source/en/api/pipelines/kandinsky_v22.md
+++ b/docs/source/en/api/pipelines/kandinsky_v22.md
@@ -25,7 +25,7 @@ Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community)
-Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/kolors.md b/docs/source/en/api/pipelines/kolors.md
index 367eb4a48548..3c08cf3ae300 100644
--- a/docs/source/en/api/pipelines/kolors.md
+++ b/docs/source/en/api/pipelines/kolors.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Kolors: Effective Training of Diffusion Model for Photorealistic Text-to-Image Synthesis
+
+
+
+

Kolors is a large-scale text-to-image generation model based on latent diffusion, developed by [the Kuaishou Kolors team](https://github.com/Kwai-Kolors/Kolors). Trained on billions of text-image pairs, Kolors exhibits significant advantages over both open-source and closed-source models in visual quality, complex semantic accuracy, and text rendering for both Chinese and English characters. Furthermore, Kolors supports both Chinese and English inputs, demonstrating strong performance in understanding and generating Chinese-specific content. For more details, please refer to this [technical report](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf).
diff --git a/docs/source/en/api/pipelines/latent_consistency_models.md b/docs/source/en/api/pipelines/latent_consistency_models.md
index 4d944510445c..a4d3bad0a7ac 100644
--- a/docs/source/en/api/pipelines/latent_consistency_models.md
+++ b/docs/source/en/api/pipelines/latent_consistency_models.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Latent Consistency Models
+
+
+
+
Latent Consistency Models (LCMs) were proposed in [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://huggingface.co/papers/2310.04378) by Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao.
The abstract of the paper is as follows:
diff --git a/docs/source/en/api/pipelines/latent_diffusion.md b/docs/source/en/api/pipelines/latent_diffusion.md
index ab50faebbfba..e5cc7c1ab069 100644
--- a/docs/source/en/api/pipelines/latent_diffusion.md
+++ b/docs/source/en/api/pipelines/latent_diffusion.md
@@ -22,7 +22,7 @@ The original codebase can be found at [CompVis/latent-diffusion](https://github.
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/latte.md b/docs/source/en/api/pipelines/latte.md
index c2154d5d47c1..26e087442cdc 100644
--- a/docs/source/en/api/pipelines/latte.md
+++ b/docs/source/en/api/pipelines/latte.md
@@ -28,7 +28,7 @@ This pipeline was contributed by [maxin-cn](https://github.com/maxin-cn). The or
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
@@ -70,6 +70,47 @@ Without torch.compile(): Average inference time: 16.246 seconds.
With torch.compile(): Average inference time: 14.573 seconds.
```
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LattePipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, LatteTransformer3DModel, LattePipeline
+from diffusers.utils import export_to_gif
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = T5EncoderModel.from_pretrained(
+ "maxin-cn/Latte-1",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = LatteTransformer3DModel.from_pretrained(
+ "maxin-cn/Latte-1",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = LattePipeline.from_pretrained(
+ "maxin-cn/Latte-1",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = "A small cactus with a happy face in the Sahara desert."
+video = pipeline(prompt).frames[0]
+export_to_gif(video, "latte.gif")
+```
+
## LattePipeline
[[autodoc]] LattePipeline
diff --git a/docs/source/en/api/pipelines/ledits_pp.md b/docs/source/en/api/pipelines/ledits_pp.md
index 4d268a252edf..0dc4b536ab42 100644
--- a/docs/source/en/api/pipelines/ledits_pp.md
+++ b/docs/source/en/api/pipelines/ledits_pp.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# LEDITS++
+
+
+
+
LEDITS++ was proposed in [LEDITS++: Limitless Image Editing using Text-to-Image Models](https://huggingface.co/papers/2311.16711) by Manuel Brack, Felix Friedrich, Katharina Kornmeier, Linoy Tsaban, Patrick Schramowski, Kristian Kersting, Apolinário Passos.
The abstract from the paper is:
diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md
new file mode 100644
index 000000000000..4bc22c0f9f6c
--- /dev/null
+++ b/docs/source/en/api/pipelines/ltx_video.md
@@ -0,0 +1,207 @@
+
+
+# LTX Video
+
+
+
+
+
+[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+Available models:
+
+| Model name | Recommended dtype |
+|:-------------:|:-----------------:|
+| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
+| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
+
+Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
+
+## Loading Single Files
+
+Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
+
+```python
+import torch
+from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
+
+# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
+single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
+transformer = LTXVideoTransformer3DModel.from_single_file(
+ single_file_url, torch_dtype=torch.bfloat16
+)
+vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
+pipe = LTXImageToVideoPipeline.from_pretrained(
+ "Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
+)
+
+# ... inference code ...
+```
+
+Alternatively, the pipeline can be used to load the weights with [`~FromSingleFileMixin.from_single_file`].
+
+```python
+import torch
+from diffusers import LTXImageToVideoPipeline
+from transformers import T5EncoderModel, T5Tokenizer
+
+single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
+text_encoder = T5EncoderModel.from_pretrained(
+ "Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16
+)
+tokenizer = T5Tokenizer.from_pretrained(
+ "Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16
+)
+pipe = LTXImageToVideoPipeline.from_single_file(
+ single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16
+)
+```
+
+Loading [LTX GGUF checkpoints](https://huggingface.co/city96/LTX-Video-gguf) are also supported:
+
+```py
+import torch
+from diffusers.utils import export_to_video
+from diffusers import LTXPipeline, LTXVideoTransformer3DModel, GGUFQuantizationConfig
+
+ckpt_path = (
+ "https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
+)
+transformer = LTXVideoTransformer3DModel.from_single_file(
+ ckpt_path,
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+ torch_dtype=torch.bfloat16,
+)
+pipe = LTXPipeline.from_pretrained(
+ "Lightricks/LTX-Video",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16,
+)
+pipe.enable_model_cpu_offload()
+
+prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+video = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=704,
+ height=480,
+ num_frames=161,
+ num_inference_steps=50,
+).frames[0]
+export_to_video(video, "output_gguf_ltx.mp4", fps=24)
+```
+
+Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
+
+
+
+Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.
+
+```python
+import torch
+from diffusers import LTXPipeline
+from diffusers.utils import export_to_video
+
+pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+video = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=768,
+ height=512,
+ num_frames=161,
+ decode_timestep=0.03,
+ decode_noise_scale=0.025,
+ num_inference_steps=50,
+).frames[0]
+export_to_video(video, "output.mp4", fps=24)
+```
+
+Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LTXPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, LTXVideoTransformer3DModel, LTXPipeline
+from diffusers.utils import export_to_video
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = T5EncoderModel.from_pretrained(
+ "Lightricks/LTX-Video",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = LTXVideoTransformer3DModel.from_pretrained(
+ "Lightricks/LTX-Video",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = LTXPipeline.from_pretrained(
+ "Lightricks/LTX-Video",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
+video = pipeline(prompt=prompt, num_frames=161, num_inference_steps=50).frames[0]
+export_to_video(video, "ship.mp4", fps=24)
+```
+
+## LTXPipeline
+
+[[autodoc]] LTXPipeline
+ - all
+ - __call__
+
+## LTXImageToVideoPipeline
+
+[[autodoc]] LTXImageToVideoPipeline
+ - all
+ - __call__
+
+## LTXConditionPipeline
+
+[[autodoc]] LTXConditionPipeline
+ - all
+ - __call__
+
+## LTXPipelineOutput
+
+[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
diff --git a/docs/source/en/api/pipelines/lumina.md b/docs/source/en/api/pipelines/lumina.md
index cc8aceefc1b1..ce5cf8b103cc 100644
--- a/docs/source/en/api/pipelines/lumina.md
+++ b/docs/source/en/api/pipelines/lumina.md
@@ -47,7 +47,7 @@ This pipeline was contributed by [PommesPeter](https://github.com/PommesPeter).
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
@@ -58,10 +58,10 @@ Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fa
First, load the pipeline:
```python
-from diffusers import LuminaText2ImgPipeline
+from diffusers import LuminaPipeline
import torch
-pipeline = LuminaText2ImgPipeline.from_pretrained(
+pipeline = LuminaPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")
```
@@ -82,9 +82,49 @@ pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fu
image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures").images[0]
```
-## LuminaText2ImgPipeline
+## Quantization
-[[autodoc]] LuminaText2ImgPipeline
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaPipeline
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = T5EncoderModel.from_pretrained(
+ "Alpha-VLLM/Lumina-Next-SFT-diffusers",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = Transformer2DModel.from_pretrained(
+ "Alpha-VLLM/Lumina-Next-SFT-diffusers",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = LuminaPipeline.from_pretrained(
+ "Alpha-VLLM/Lumina-Next-SFT-diffusers",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon"
+image = pipeline(prompt).images[0]
+image.save("lumina.png")
+```
+
+## LuminaPipeline
+
+[[autodoc]] LuminaPipeline
- all
- __call__
diff --git a/docs/source/en/api/pipelines/lumina2.md b/docs/source/en/api/pipelines/lumina2.md
new file mode 100644
index 000000000000..57f0e8e2105d
--- /dev/null
+++ b/docs/source/en/api/pipelines/lumina2.md
@@ -0,0 +1,87 @@
+
+
+# Lumina2
+
+
+
+
+
+[Lumina Image 2.0: A Unified and Efficient Image Generative Model](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) is a 2 billion parameter flow-based diffusion transformer capable of generating diverse images from text descriptions.
+
+The abstract from the paper is:
+
+*We introduce Lumina-Image 2.0, an advanced text-to-image model that surpasses previous state-of-the-art methods across multiple benchmarks, while also shedding light on its potential to evolve into a generalist vision intelligence model. Lumina-Image 2.0 exhibits three key properties: (1) Unification – it adopts a unified architecture that treats text and image tokens as a joint sequence, enabling natural cross-modal interactions and facilitating task expansion. Besides, since high-quality captioners can provide semantically better-aligned text-image training pairs, we introduce a unified captioning system, UniCaptioner, which generates comprehensive and precise captions for the model. This not only accelerates model convergence but also enhances prompt adherence, variable-length prompt handling, and task generalization via prompt templates. (2) Efficiency – to improve the efficiency of the unified architecture, we develop a set of optimization techniques that improve semantic learning and fine-grained texture generation during training while incorporating inference-time acceleration strategies without compromising image quality. (3) Transparency – we open-source all training details, code, and models to ensure full reproducibility, aiming to bridge the gap between well-resourced closed-source research teams and independent developers.*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## Using Single File loading with Lumina Image 2.0
+
+Single file loading for Lumina Image 2.0 is available for the `Lumina2Transformer2DModel`
+
+```python
+import torch
+from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline
+
+ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth"
+transformer = Lumina2Transformer2DModel.from_single_file(
+ ckpt_path, torch_dtype=torch.bfloat16
+)
+
+pipe = Lumina2Pipeline.from_pretrained(
+ "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
+)
+pipe.enable_model_cpu_offload()
+image = pipe(
+ "a cat holding a sign that says hello",
+ generator=torch.Generator("cpu").manual_seed(0),
+).images[0]
+image.save("lumina-single-file.png")
+
+```
+
+## Using GGUF Quantized Checkpoints with Lumina Image 2.0
+
+GGUF Quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with the `GGUFQuantizationConfig`
+
+```python
+from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline, GGUFQuantizationConfig
+
+ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf"
+transformer = Lumina2Transformer2DModel.from_single_file(
+ ckpt_path,
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+ torch_dtype=torch.bfloat16,
+)
+
+pipe = Lumina2Pipeline.from_pretrained(
+ "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
+)
+pipe.enable_model_cpu_offload()
+image = pipe(
+ "a cat holding a sign that says hello",
+ generator=torch.Generator("cpu").manual_seed(0),
+).images[0]
+image.save("lumina-gguf.png")
+```
+
+## Lumina2Pipeline
+
+[[autodoc]] Lumina2Pipeline
+ - all
+ - __call__
diff --git a/docs/source/en/api/pipelines/marigold.md b/docs/source/en/api/pipelines/marigold.md
index 374947ce95ab..e9ca0df067ba 100644
--- a/docs/source/en/api/pipelines/marigold.md
+++ b/docs/source/en/api/pipelines/marigold.md
@@ -1,4 +1,6 @@
-
-# Marigold Pipelines for Computer Vision Tasks
+# Marigold Computer Vision

-Marigold was proposed in [Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation](https://huggingface.co/papers/2312.02145), a CVPR 2024 Oral paper by [Bingxin Ke](http://www.kebingxin.com/), [Anton Obukhov](https://www.obukhov.ai/), [Shengyu Huang](https://shengyuh.github.io/), [Nando Metzger](https://nandometzger.github.io/), [Rodrigo Caye Daudt](https://rcdaudt.github.io/), and [Konrad Schindler](https://scholar.google.com/citations?user=FZuNgqIAAAAJ&hl=en).
-The idea is to repurpose the rich generative prior of Text-to-Image Latent Diffusion Models (LDMs) for traditional computer vision tasks.
-Initially, this idea was explored to fine-tune Stable Diffusion for Monocular Depth Estimation, as shown in the teaser above.
-Later,
-- [Tianfu Wang](https://tianfwang.github.io/) trained the first Latent Consistency Model (LCM) of Marigold, which unlocked fast single-step inference;
-- [Kevin Qu](https://www.linkedin.com/in/kevin-qu-b3417621b/?locale=en_US) extended the approach to Surface Normals Estimation;
-- [Anton Obukhov](https://www.obukhov.ai/) contributed the pipelines and documentation into diffusers (enabled and supported by [YiYi Xu](https://yiyixuxu.github.io/) and [Sayak Paul](https://sayak.dev/)).
+Marigold was proposed in
+[Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation](https://huggingface.co/papers/2312.02145),
+a CVPR 2024 Oral paper by
+[Bingxin Ke](http://www.kebingxin.com/),
+[Anton Obukhov](https://www.obukhov.ai/),
+[Shengyu Huang](https://shengyuh.github.io/),
+[Nando Metzger](https://nandometzger.github.io/),
+[Rodrigo Caye Daudt](https://rcdaudt.github.io/), and
+[Konrad Schindler](https://scholar.google.com/citations?user=FZuNgqIAAAAJ&hl=en).
+The core idea is to **repurpose the generative prior of Text-to-Image Latent Diffusion Models (LDMs) for traditional
+computer vision tasks**.
+This approach was explored by fine-tuning Stable Diffusion for **Monocular Depth Estimation**, as demonstrated in the
+teaser above.
+
+Marigold was later extended in the follow-up paper,
+[Marigold: Affordable Adaptation of Diffusion-Based Image Generators for Image Analysis](https://huggingface.co/papers/2312.02145),
+authored by
+[Bingxin Ke](http://www.kebingxin.com/),
+[Kevin Qu](https://www.linkedin.com/in/kevin-qu-b3417621b/?locale=en_US),
+[Tianfu Wang](https://tianfwang.github.io/),
+[Nando Metzger](https://nandometzger.github.io/),
+[Shengyu Huang](https://shengyuh.github.io/),
+[Bo Li](https://www.linkedin.com/in/bobboli0202/),
+[Anton Obukhov](https://www.obukhov.ai/), and
+[Konrad Schindler](https://scholar.google.com/citations?user=FZuNgqIAAAAJ&hl=en).
+This work expanded Marigold to support new modalities such as **Surface Normals** and **Intrinsic Image Decomposition**
+(IID), introduced a training protocol for **Latent Consistency Models** (LCM), and demonstrated **High-Resolution** (HR)
+processing capability.
-The abstract from the paper is:
+
-*Monocular depth estimation is a fundamental computer vision task. Recovering 3D depth from a single image is geometrically ill-posed and requires scene understanding, so it is not surprising that the rise of deep learning has led to a breakthrough. The impressive progress of monocular depth estimators has mirrored the growth in model capacity, from relatively modest CNNs to large Transformer architectures. Still, monocular depth estimators tend to struggle when presented with images with unfamiliar content and layout, since their knowledge of the visual world is restricted by the data seen during training, and challenged by zero-shot generalization to new domains. This motivates us to explore whether the extensive priors captured in recent generative diffusion models can enable better, more generalizable depth estimation. We introduce Marigold, a method for affine-invariant monocular depth estimation that is derived from Stable Diffusion and retains its rich prior knowledge. The estimator can be fine-tuned in a couple of days on a single GPU using only synthetic training data. It delivers state-of-the-art performance across a wide range of datasets, including over 20% performance gains in specific cases. Project page: https://marigoldmonodepth.github.io.*
+The early Marigold models (`v1-0` and earlier) were optimized for best results with at least 10 inference steps.
+LCM models were later developed to enable high-quality inference in just 1 to 4 steps.
+Marigold models `v1-1` and later use the DDIM scheduler to achieve optimal
+results in as few as 1 to 4 steps.
-## Available Pipelines
+
-Each pipeline supports one Computer Vision task, which takes an input RGB image as input and produces a *prediction* of the modality of interest, such as a depth map of the input image.
-Currently, the following tasks are implemented:
+## Available Pipelines
-| Pipeline | Predicted Modalities | Demos |
-|---------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------:|
-| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-lcm), [Slow Original Demo (DDIM)](https://huggingface.co/spaces/prs-eth/marigold) |
-| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-normals-lcm) |
+Each pipeline is tailored for a specific computer vision task, processing an input RGB image and generating a
+corresponding prediction.
+Currently, the following computer vision tasks are implemented:
+| Pipeline | Recommended Model Checkpoints | Spaces (Interactive Apps) | Predicted Modalities |
+|---------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1) | [Depth Estimation](https://huggingface.co/spaces/prs-eth/marigold) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) |
+| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [prs-eth/marigold-normals-v1-1](https://huggingface.co/prs-eth/marigold-normals-v1-1) | [Surface Normals Estimation](https://huggingface.co/spaces/prs-eth/marigold-normals) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) |
+| [MarigoldIntrinsicsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py) | [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1), [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | [Intrinsic Image Decomposition](https://huggingface.co/spaces/prs-eth/marigold-iid) | [Albedo](https://en.wikipedia.org/wiki/Albedo), [Materials](https://www.n.aiq3d.com/wiki/roughnessmetalnessao-map), [Lighting](https://en.wikipedia.org/wiki/Diffuse_reflection) |
## Available Checkpoints
-The original checkpoints can be found under the [PRS-ETH](https://huggingface.co/prs-eth/) Hugging Face organization.
+All original checkpoints are available under the [PRS-ETH](https://huggingface.co/prs-eth/) organization on Hugging Face.
+They are designed for use with diffusers pipelines and the [original codebase](https://github.com/prs-eth/marigold), which can also be used to train
+new model checkpoints.
+The following is a summary of the recommended checkpoints, all of which produce reliable results with 1 to 4 steps.
+
+| Checkpoint | Modality | Comment |
+|-----------------------------------------------------------------------------------------------------|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1) | Depth | Affine-invariant depth prediction assigns each pixel a value between 0 (near plane) and 1 (far plane), with both planes determined by the model during inference. |
+| [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | The surface normals predictions are unit-length 3D vectors in the screen space camera, with values in the range from -1 to 1. |
+| [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1) | Intrinsics | InteriorVerse decomposition is comprised of Albedo and two BRDF material properties: Roughness and Metallicity. |
+| [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | Intrinsics | HyperSim decomposition of an image  \\(I\\)  is comprised of Albedo  \\(A\\), Diffuse shading  \\(S\\), and Non-diffuse residual  \\(R\\):  \\(I = A*S+R\\). |
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage).
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff
+between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to
+efficiently load the same components into multiple pipelines.
+Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section
+[here](../../using-diffusers/svd#reduce-memory-usage).
-Marigold pipelines were designed and tested only with `DDIMScheduler` and `LCMScheduler`.
-Depending on the scheduler, the number of inference steps required to get reliable predictions varies, and there is no universal value that works best across schedulers.
-Because of that, the default value of `num_inference_steps` in the `__call__` method of the pipeline is set to `None` (see the API reference).
-Unless set explicitly, its value will be taken from the checkpoint configuration `model_index.json`.
-This is done to ensure high-quality predictions when calling the pipeline with just the `image` argument.
+Marigold pipelines were designed and tested with the scheduler embedded in the model checkpoint.
+The optimal number of inference steps varies by scheduler, with no universal value that works best across all cases.
+To accommodate this, the `num_inference_steps` parameter in the pipeline's `__call__` method defaults to `None` (see the
+API reference).
+Unless set explicitly, it inherits the value from the `default_denoising_steps` field in the checkpoint configuration
+file (`model_index.json`).
+This ensures high-quality predictions when invoking the pipeline with only the `image` argument.
-See also Marigold [usage examples](marigold_usage).
+See also Marigold [usage examples](../../using-diffusers/marigold_usage).
+
+## Marigold Depth Prediction API
-## MarigoldDepthPipeline
[[autodoc]] MarigoldDepthPipeline
- - all
- __call__
-## MarigoldNormalsPipeline
+[[autodoc]] pipelines.marigold.pipeline_marigold_depth.MarigoldDepthOutput
+
+[[autodoc]] pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_depth
+
+## Marigold Normals Estimation API
[[autodoc]] MarigoldNormalsPipeline
- - all
- __call__
-## MarigoldDepthOutput
-[[autodoc]] pipelines.marigold.pipeline_marigold_depth.MarigoldDepthOutput
+[[autodoc]] pipelines.marigold.pipeline_marigold_normals.MarigoldNormalsOutput
+
+[[autodoc]] pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals
+
+## Marigold Intrinsic Image Decomposition API
+
+[[autodoc]] MarigoldIntrinsicsPipeline
+ - __call__
+
+[[autodoc]] pipelines.marigold.pipeline_marigold_intrinsics.MarigoldIntrinsicsOutput
-## MarigoldNormalsOutput
-[[autodoc]] pipelines.marigold.pipeline_marigold_normals.MarigoldNormalsOutput
\ No newline at end of file
+[[autodoc]] pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_intrinsics
diff --git a/docs/source/en/api/pipelines/mochi.md b/docs/source/en/api/pipelines/mochi.md
new file mode 100644
index 000000000000..ccbaf40af8f8
--- /dev/null
+++ b/docs/source/en/api/pipelines/mochi.md
@@ -0,0 +1,279 @@
+
+
+# Mochi 1 Preview
+
+
+
+
+
+> [!TIP]
+> Only a research preview of the model weights is available at the moment.
+
+[Mochi 1](https://huggingface.co/genmo/mochi-1-preview) is a video generation model by Genmo with a strong focus on prompt adherence and motion quality. The model features a 10B parameter Asmmetric Diffusion Transformer (AsymmDiT) architecture, and uses non-square QKV and output projection layers to reduce inference memory requirements. A single T5-XXL model is used to encode prompts.
+
+*Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. The model is released under a permissive Apache 2.0 license.*
+
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`MochiPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, MochiTransformer3DModel, MochiPipeline
+from diffusers.utils import export_to_video
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = T5EncoderModel.from_pretrained(
+ "genmo/mochi-1-preview",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = MochiTransformer3DModel.from_pretrained(
+ "genmo/mochi-1-preview",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = MochiPipeline.from_pretrained(
+ "genmo/mochi-1-preview",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+video = pipeline(
+ "Close-up of a cats eye, with the galaxy reflected in the cats eye. Ultra high resolution 4k.",
+ num_inference_steps=28,
+ guidance_scale=3.5
+).frames[0]
+export_to_video(video, "cat.mp4")
+```
+
+## Generating videos with Mochi-1 Preview
+
+The following example will download the full precision `mochi-1-preview` weights and produce the highest quality results but will require at least 42GB VRAM to run.
+
+```python
+import torch
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+
+pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
+
+# Enable memory savings
+pipe.enable_model_cpu_offload()
+pipe.enable_vae_tiling()
+
+prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
+
+with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
+ frames = pipe(prompt, num_frames=85).frames[0]
+
+export_to_video(frames, "mochi.mp4", fps=30)
+```
+
+## Using a lower precision variant to save memory
+
+The following example will use the `bfloat16` variant of the model and requires 22GB VRAM to run. There is a slight drop in the quality of the generated video as a result.
+
+```python
+import torch
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+
+pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
+
+# Enable memory savings
+pipe.enable_model_cpu_offload()
+pipe.enable_vae_tiling()
+
+prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
+frames = pipe(prompt, num_frames=85).frames[0]
+
+export_to_video(frames, "mochi.mp4", fps=30)
+```
+
+## Reproducing the results from the Genmo Mochi repo
+
+The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the original implementation, please refer to the following example.
+
+
+The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.
+
+When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision.
+
+
+
+Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`.
+
+
+```python
+import torch
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+from diffusers.video_processor import VideoProcessor
+
+pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)
+pipe.enable_vae_tiling()
+pipe.enable_model_cpu_offload()
+
+prompt = "An aerial shot of a parade of elephants walking across the African savannah. The camera showcases the herd and the surrounding landscape."
+
+with torch.no_grad():
+ prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
+ pipe.encode_prompt(prompt=prompt)
+ )
+
+with torch.autocast("cuda", torch.bfloat16):
+ with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+ frames = pipe(
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_embeds=negative_prompt_embeds,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ guidance_scale=4.5,
+ num_inference_steps=64,
+ height=480,
+ width=848,
+ num_frames=163,
+ generator=torch.Generator("cuda").manual_seed(0),
+ output_type="latent",
+ return_dict=False,
+ )[0]
+
+video_processor = VideoProcessor(vae_scale_factor=8)
+has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
+has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
+if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
+ )
+ latents_std = (
+ torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
+ )
+ frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean
+else:
+ frames = frames / pipe.vae.config.scaling_factor
+
+with torch.no_grad():
+ video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]
+
+video = video_processor.postprocess_video(video)[0]
+export_to_video(video, "mochi.mp4", fps=30)
+```
+
+## Running inference with multiple GPUs
+
+It is possible to split the large Mochi transformer across multiple GPUs using the `device_map` and `max_memory` options in `from_pretrained`. In the following example we split the model across two GPUs, each with 24GB of VRAM.
+
+```python
+import torch
+from diffusers import MochiPipeline, MochiTransformer3DModel
+from diffusers.utils import export_to_video
+
+model_id = "genmo/mochi-1-preview"
+transformer = MochiTransformer3DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ device_map="auto",
+ max_memory={0: "24GB", 1: "24GB"}
+)
+
+pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
+pipe.enable_model_cpu_offload()
+pipe.enable_vae_tiling()
+
+with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
+ frames = pipe(
+ prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
+ negative_prompt="",
+ height=480,
+ width=848,
+ num_frames=85,
+ num_inference_steps=50,
+ guidance_scale=4.5,
+ num_videos_per_prompt=1,
+ generator=torch.Generator(device="cuda").manual_seed(0),
+ max_sequence_length=256,
+ output_type="pil",
+ ).frames[0]
+
+export_to_video(frames, "output.mp4", fps=30)
+```
+
+## Using single file loading with the Mochi Transformer
+
+You can use `from_single_file` to load the Mochi transformer in its original format.
+
+
+Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints.
+
+
+```python
+import torch
+from diffusers import MochiPipeline, MochiTransformer3DModel
+from diffusers.utils import export_to_video
+
+model_id = "genmo/mochi-1-preview"
+
+ckpt_path = "https://huggingface.co/Comfy-Org/mochi_preview_repackaged/blob/main/split_files/diffusion_models/mochi_preview_bf16.safetensors"
+
+transformer = MochiTransformer3DModel.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16)
+
+pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
+pipe.enable_model_cpu_offload()
+pipe.enable_vae_tiling()
+
+with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
+ frames = pipe(
+ prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
+ negative_prompt="",
+ height=480,
+ width=848,
+ num_frames=85,
+ num_inference_steps=50,
+ guidance_scale=4.5,
+ num_videos_per_prompt=1,
+ generator=torch.Generator(device="cuda").manual_seed(0),
+ max_sequence_length=256,
+ output_type="pil",
+ ).frames[0]
+
+export_to_video(frames, "output.mp4", fps=30)
+```
+
+## MochiPipeline
+
+[[autodoc]] MochiPipeline
+ - all
+ - __call__
+
+## MochiPipelineOutput
+
+[[autodoc]] pipelines.mochi.pipeline_output.MochiPipelineOutput
diff --git a/docs/source/en/api/pipelines/musicldm.md b/docs/source/en/api/pipelines/musicldm.md
index 3ffb6541405d..412e8e41c2ca 100644
--- a/docs/source/en/api/pipelines/musicldm.md
+++ b/docs/source/en/api/pipelines/musicldm.md
@@ -42,7 +42,7 @@ During inference:
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md
new file mode 100644
index 000000000000..114e3753e710
--- /dev/null
+++ b/docs/source/en/api/pipelines/omnigen.md
@@ -0,0 +1,80 @@
+
+
+# OmniGen
+
+[OmniGen: Unified Image Generation](https://arxiv.org/pdf/2409.11340) from BAAI, by Shitao Xiao, Yueze Wang, Junjie Zhou, Huaying Yuan, Xingrun Xing, Ruiran Yan, Chaofan Li, Shuting Wang, Tiejun Huang, Zheng Liu.
+
+The abstract from the paper is:
+
+*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1).
+
+## Inference
+
+First, load the pipeline:
+
+```python
+import torch
+from diffusers import OmniGenPipeline
+
+pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+```
+
+For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
+You can try setting the `height` and `width` parameters to generate images with different size.
+
+```python
+prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
+image = pipe(
+ prompt=prompt,
+ height=1024,
+ width=1024,
+ guidance_scale=3,
+ generator=torch.Generator(device="cpu").manual_seed(111),
+).images[0]
+image.save("output.png")
+```
+
+OmniGen supports multimodal inputs.
+When the input includes an image, you need to add a placeholder ` <|image_1|>` in the text prompt to represent the image.
+It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
+
+```python
+prompt=" <|image_1|> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
+image.save("output.png")
+```
+
+## OmniGenPipeline
+
+[[autodoc]] OmniGenPipeline
+ - all
+ - __call__
diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md
index 02c77d197e34..6a8e82a692e0 100644
--- a/docs/source/en/api/pipelines/overview.md
+++ b/docs/source/en/api/pipelines/overview.md
@@ -54,7 +54,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [DiT](dit) | text2image |
| [Flux](flux) | text2image |
| [Hunyuan-DiT](hunyuandit) | text2image |
-| [I2VGen-XL](i2vgenxl) | text2video |
+| [I2VGen-XL](i2vgenxl) | image2video |
| [InstructPix2Pix](pix2pix) | image editing |
| [Kandinsky 2.1](kandinsky) | text2image, image2image, inpainting, interpolation |
| [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting |
@@ -65,7 +65,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Latte](latte) | text2image |
| [LEDITS++](ledits_pp) | image editing |
| [Lumina-T2X](lumina) | text2image |
-| [Marigold](marigold) | depth |
+| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
| [MultiDiffusion](panorama) | text2image |
| [MusicLDM](musicldm) | text2audio |
| [PAG](pag) | text2image |
diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md
index cc6d075f457f..64aefdf7e78f 100644
--- a/docs/source/en/api/pipelines/pag.md
+++ b/docs/source/en/api/pipelines/pag.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Perturbed-Attention Guidance
+
+
+
+
[Perturbed-Attention Guidance (PAG)](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) is a new diffusion sampling guidance that improves sample quality across both unconditional and conditional settings, achieving this without requiring further training or the integration of external modules.
PAG was introduced in [Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance](https://huggingface.co/papers/2403.17377) by Donghoon Ahn, Hyoungwon Cho, Jaewon Min, Wooseok Jang, Jungwoo Kim, SeonHwa Kim, Hyun Hee Park, Kyong Hwan Jin and Seungryong Kim.
@@ -48,6 +52,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- all
- __call__
+## StableDiffusionPAGInpaintPipeline
+[[autodoc]] StableDiffusionPAGInpaintPipeline
+ - all
+ - __call__
+
## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
- all
@@ -96,6 +105,10 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- all
- __call__
+## StableDiffusion3PAGImg2ImgPipeline
+[[autodoc]] StableDiffusion3PAGImg2ImgPipeline
+ - all
+ - __call__
## PixArtSigmaPAGPipeline
[[autodoc]] PixArtSigmaPAGPipeline
diff --git a/docs/source/en/api/pipelines/paint_by_example.md b/docs/source/en/api/pipelines/paint_by_example.md
index effd608873fd..75360596d676 100644
--- a/docs/source/en/api/pipelines/paint_by_example.md
+++ b/docs/source/en/api/pipelines/paint_by_example.md
@@ -26,7 +26,7 @@ Paint by Example is supported by the official [Fantasy-Studio/Paint-by-Example](
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/panorama.md b/docs/source/en/api/pipelines/panorama.md
index b34008ad830f..cbd5aaf815db 100644
--- a/docs/source/en/api/pipelines/panorama.md
+++ b/docs/source/en/api/pipelines/panorama.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# MultiDiffusion
+
+
+
+
[MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation](https://huggingface.co/papers/2302.08113) is by Omer Bar-Tal, Lior Yariv, Yaron Lipman, and Tali Dekel.
The abstract from the paper is:
@@ -37,7 +41,7 @@ But with circular padding, the right and the left parts are matching (`circular_
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md
index 8ba78252c99b..86c0e8eb191a 100644
--- a/docs/source/en/api/pipelines/pia.md
+++ b/docs/source/en/api/pipelines/pia.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Image-to-Video Generation with PIA (Personalized Image Animator)
+
+
+
+
## Overview
[PIA: Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models](https://arxiv.org/abs/2312.13964) by Yiming Zhang, Zhening Xing, Yanhong Zeng, Youqing Fang, Kai Chen
diff --git a/docs/source/en/api/pipelines/pix2pix.md b/docs/source/en/api/pipelines/pix2pix.md
index 52767a90b214..d0b3bf32b823 100644
--- a/docs/source/en/api/pipelines/pix2pix.md
+++ b/docs/source/en/api/pipelines/pix2pix.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# InstructPix2Pix
+
+
+
+
[InstructPix2Pix: Learning to Follow Image Editing Instructions](https://huggingface.co/papers/2211.09800) is by Tim Brooks, Aleksander Holynski and Alexei A. Efros.
The abstract from the paper is:
@@ -22,7 +26,7 @@ You can find additional information about InstructPix2Pix on the [project page](
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/pixart.md b/docs/source/en/api/pipelines/pixart.md
index b2bef501b237..d4e268b81d49 100644
--- a/docs/source/en/api/pipelines/pixart.md
+++ b/docs/source/en/api/pipelines/pixart.md
@@ -31,7 +31,7 @@ Some notes about this pipeline:
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/sana.md b/docs/source/en/api/pipelines/sana.md
new file mode 100644
index 000000000000..3702b2771974
--- /dev/null
+++ b/docs/source/en/api/pipelines/sana.md
@@ -0,0 +1,111 @@
+
+
+# SanaPipeline
+
+
+
+
+
+[SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
+
+The abstract from the paper is:
+
+*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj) and [chenjy2003](https://github.com/chenjy2003). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model).
+
+Available models:
+
+| Model | Recommended dtype |
+|:-----:|:-----------------:|
+| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
+| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
+| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
+| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
+| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
+| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
+| [`Efficient-Large-Model/Sana_600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_512px_diffusers) | `torch.float16` |
+
+Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) collection for more information.
+
+Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
+
+
+
+Make sure to pass the `variant` argument for downloaded checkpoints to use lower disk space. Set it to `"fp16"` for models with recommended dtype as `torch.float16`, and `"bf16"` for models with recommended dtype as `torch.bfloat16`. By default, `torch.float32` weights are downloaded, which use twice the amount of disk storage. Additionally, `torch.float32` weights can be downcasted on-the-fly by specifying the `torch_dtype` argument. Read about it in the [docs](https://huggingface.co/docs/diffusers/v0.31.0/en/api/pipelines/overview#diffusers.DiffusionPipeline.from_pretrained).
+
+
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = AutoModel.from_pretrained(
+ "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = SanaTransformer2DModel.from_pretrained(
+ "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = SanaPipeline.from_pretrained(
+ "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon"
+image = pipeline(prompt).images[0]
+image.save("sana.png")
+```
+
+## SanaPipeline
+
+[[autodoc]] SanaPipeline
+ - all
+ - __call__
+
+## SanaPAGPipeline
+
+[[autodoc]] SanaPAGPipeline
+ - all
+ - __call__
+
+## SanaPipelineOutput
+
+[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
diff --git a/docs/source/en/api/pipelines/sana_sprint.md b/docs/source/en/api/pipelines/sana_sprint.md
new file mode 100644
index 000000000000..8db4576cf579
--- /dev/null
+++ b/docs/source/en/api/pipelines/sana_sprint.md
@@ -0,0 +1,100 @@
+
+
+# SanaSprintPipeline
+
+
+
+
+
+[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA, MIT HAN Lab, and Hugging Face by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han
+
+The abstract from the paper is:
+
+*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
+
+Available models:
+
+| Model | Recommended dtype |
+|:-------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|
+| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16` |
+| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16` |
+
+Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information.
+
+Note: The recommended dtype mentioned is for the transformer weights. The text encoder must stay in `torch.bfloat16` and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
+
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = AutoModel.from_pretrained(
+ "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = SanaTransformer2DModel.from_pretrained(
+ "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+)
+
+pipeline = SanaSprintPipeline.from_pretrained(
+ "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.bfloat16,
+ device_map="balanced",
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon"
+image = pipeline(prompt).images[0]
+image.save("sana.png")
+```
+
+## Setting `max_timesteps`
+
+Users can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper.
+
+## SanaSprintPipeline
+
+[[autodoc]] SanaSprintPipeline
+ - all
+ - __call__
+
+
+## SanaPipelineOutput
+
+[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
diff --git a/docs/source/en/api/pipelines/self_attention_guidance.md b/docs/source/en/api/pipelines/self_attention_guidance.md
index e56aae2a775b..d656ce93f104 100644
--- a/docs/source/en/api/pipelines/self_attention_guidance.md
+++ b/docs/source/en/api/pipelines/self_attention_guidance.md
@@ -22,7 +22,7 @@ You can find additional information about Self-Attention Guidance on the [projec
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/semantic_stable_diffusion.md b/docs/source/en/api/pipelines/semantic_stable_diffusion.md
index 19a0a8116989..b9aacd3518d8 100644
--- a/docs/source/en/api/pipelines/semantic_stable_diffusion.md
+++ b/docs/source/en/api/pipelines/semantic_stable_diffusion.md
@@ -21,7 +21,7 @@ The abstract from the paper is:
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/shap_e.md b/docs/source/en/api/pipelines/shap_e.md
index 9f9155c79e89..3c1f939c1fce 100644
--- a/docs/source/en/api/pipelines/shap_e.md
+++ b/docs/source/en/api/pipelines/shap_e.md
@@ -19,7 +19,7 @@ The original codebase can be found at [openai/shap-e](https://github.com/openai/
-See the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+See the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/stable_audio.md b/docs/source/en/api/pipelines/stable_audio.md
index a6d34a0697d5..1acb72b3968a 100644
--- a/docs/source/en/api/pipelines/stable_audio.md
+++ b/docs/source/en/api/pipelines/stable_audio.md
@@ -35,6 +35,57 @@ During inference:
* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`StableAudioPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, StableAudioDiTModel, StableAudioPipeline
+from diffusers.utils import export_to_video
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = T5EncoderModel.from_pretrained(
+ "stabilityai/stable-audio-open-1.0",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = StableAudioDiTModel.from_pretrained(
+ "stabilityai/stable-audio-open-1.0",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = StableAudioPipeline.from_pretrained(
+ "stabilityai/stable-audio-open-1.0",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = "The sound of a hammer hitting a wooden surface."
+negative_prompt = "Low quality."
+audio = pipeline(
+ prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=200,
+ audio_end_in_s=10.0,
+ num_waveforms_per_prompt=3,
+ generator=generator,
+).audios
+
+output = audio[0].T.float().cpu().numpy()
+sf.write("hammer.wav", output, pipeline.vae.sampling_rate)
+```
+
## StableAudioPipeline
[[autodoc]] StableAudioPipeline
diff --git a/docs/source/en/api/pipelines/stable_diffusion/depth2img.md b/docs/source/en/api/pipelines/stable_diffusion/depth2img.md
index 84dae80498a3..0cf58fe1d2fb 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/depth2img.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/depth2img.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Depth-to-image
+
+
+
+
The Stable Diffusion model can also infer depth based on an image using [MiDaS](https://github.com/isl-org/MiDaS). This allows you to pass a text prompt and an initial image to condition the generation of new images as well as a `depth_map` to preserve the image structure.
diff --git a/docs/source/en/api/pipelines/stable_diffusion/img2img.md b/docs/source/en/api/pipelines/stable_diffusion/img2img.md
index 1a62a5a48ff0..f5779de1ee62 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/img2img.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/img2img.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Image-to-image
+
+
+
+
The Stable Diffusion model can also be applied to image-to-image generation by passing a text prompt and an initial image to condition the generation of new images.
The [`StableDiffusionImg2ImgPipeline`] uses the diffusion-denoising mechanism proposed in [SDEdit: Guided Image Synthesis and Editing with Stochastic Differential Equations](https://huggingface.co/papers/2108.01073) by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan Zhu, Stefano Ermon.
diff --git a/docs/source/en/api/pipelines/stable_diffusion/inpaint.md b/docs/source/en/api/pipelines/stable_diffusion/inpaint.md
index ef605cfe8b90..f75c9ca3dd0b 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/inpaint.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/inpaint.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Inpainting
+
+
+
+
The Stable Diffusion model can also be applied to inpainting which lets you edit specific parts of an image by providing a mask and a text prompt using Stable Diffusion.
## Tips
diff --git a/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md b/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md
index 23830462c20b..f2c6ae8f1ddb 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Text-to-(RGB, depth)
+
+
+
+
LDM3D was proposed in [LDM3D: Latent Diffusion Model for 3D](https://huggingface.co/papers/2305.10853) by Gabriela Ben Melech Stan, Diana Wofk, Scottie Fox, Alex Redden, Will Saxton, Jean Yu, Estelle Aflalo, Shao-Yen Tseng, Fabio Nonato, Matthias Muller, and Vasudev Lal. LDM3D generates an image and a depth map from a given text prompt unlike the existing text-to-image diffusion models such as [Stable Diffusion](./overview) which only generates an image. With almost the same number of parameters, LDM3D achieves to create a latent space that can compress both the RGB images and the depth maps.
Two checkpoints are available for use:
diff --git a/docs/source/en/api/pipelines/stable_diffusion/overview.md b/docs/source/en/api/pipelines/stable_diffusion/overview.md
index 5087d1fdd43a..25984091215c 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/overview.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/overview.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Stable Diffusion pipelines
+
+
+
+
Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). Latent diffusion applies the diffusion process over a lower dimensional latent space to reduce memory and compute complexity. This specific type of diffusion model was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.
Stable Diffusion is trained on 512x512 images from a subset of the LAION-5B dataset. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and can run on consumer GPUs.
diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
index fd026f07c923..4ba577795b0d 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Stable Diffusion 3
+
+
+
+
Stable Diffusion 3 (SD3) was proposed in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206.pdf) by Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Muller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, Dustin Podell, Tim Dockhorn, Zion English, Kyle Lacey, Alex Goodwin, Yannik Marek, and Robin Rombach.
The abstract from the paper is:
@@ -59,9 +63,76 @@ image.save("sd3_hello_world.png")
- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
+## Image Prompting with IP-Adapters
+
+An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need:
+
+- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder.
+- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`.
+- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection.
+
+IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally.
+
+```python
+import torch
+from PIL import Image
+
+from diffusers import StableDiffusion3Pipeline
+from transformers import SiglipVisionModel, SiglipImageProcessor
+
+image_encoder_id = "google/siglip-so400m-patch14-384"
+ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter"
+
+feature_extractor = SiglipImageProcessor.from_pretrained(
+ image_encoder_id,
+ torch_dtype=torch.float16
+)
+image_encoder = SiglipVisionModel.from_pretrained(
+ image_encoder_id,
+ torch_dtype=torch.float16
+).to( "cuda")
+
+pipe = StableDiffusion3Pipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3.5-large",
+ torch_dtype=torch.float16,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+).to("cuda")
+
+pipe.load_ip_adapter(ip_adapter_id)
+pipe.set_ip_adapter_scale(0.6)
+
+ref_img = Image.open("image.jpg").convert('RGB')
+
+image = pipe(
+ width=1024,
+ height=1024,
+ prompt="a cat",
+ negative_prompt="lowres, low quality, worst quality",
+ num_inference_steps=24,
+ guidance_scale=5.0,
+ ip_adapter_image=ref_img
+).images[0]
+
+image.save("result.jpg")
+```
+
+
+
+
IP-Adapter examples with prompt "a cat"
+
+
+
+
+
+Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
+
+
+
+
## Memory Optimisations for SD3
-SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
+SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
### Running Inference with Model Offloading
@@ -201,6 +272,46 @@ image.save("sd3_hello_world.png")
Check out the full script [here](https://gist.github.com/sayakpaul/508d89d7aad4f454900813da5d42ca97).
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`StableDiffusion3Pipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SD3Transformer2DModel, StableDiffusion3Pipeline
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = T5EncoderModel.from_pretrained(
+ "stabilityai/stable-diffusion-3.5-large",
+ subfolder="text_encoder_3",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = SD3Transformer2DModel.from_pretrained(
+ "stabilityai/stable-diffusion-3.5-large",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = StableDiffusion3Pipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3.5-large",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon"
+image = pipeline(prompt, num_inference_steps=28, guidance_scale=7.0).images[0]
+image.save("sd3.png")
+```
+
## Using Long Prompts with the T5 Text Encoder
By default, the T5 Text Encoder prompt uses a maximum sequence length of `256`. This can be adjusted by setting the `max_sequence_length` to accept fewer or more tokens. Keep in mind that longer sequences require additional resources and result in longer generation times, such as during batch inference.
@@ -313,6 +424,26 @@ image = pipe("a picture of a cat holding a sign that says hello world").images[0
image.save('sd3-single-file-t5-fp8.png')
```
+### Loading the single file checkpoint for the Stable Diffusion 3.5 Transformer Model
+
+```python
+import torch
+from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline
+
+transformer = SD3Transformer2DModel.from_single_file(
+ "https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo/blob/main/sd3.5_large.safetensors",
+ torch_dtype=torch.bfloat16,
+)
+pipe = StableDiffusion3Pipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3.5-large",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16,
+)
+pipe.enable_model_cpu_offload()
+image = pipe("a cat holding a sign that says hello world").images[0]
+image.save("sd35.png")
+```
+
## StableDiffusion3Pipeline
[[autodoc]] StableDiffusion3Pipeline
diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md
index c5433c0783ba..485ee7d7fc28 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Stable Diffusion XL
+
+
+
+
Stable Diffusion XL (SDXL) was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://huggingface.co/papers/2307.01952) by Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, and Robin Rombach.
The abstract from the paper is:
diff --git a/docs/source/en/api/pipelines/stable_diffusion/text2img.md b/docs/source/en/api/pipelines/stable_diffusion/text2img.md
index 86f3090fe9fd..c7ac145712cb 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/text2img.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/text2img.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Text-to-image
+
+
+
+
The Stable Diffusion model was created by researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [Runway](https://github.com/runwayml), and [LAION](https://laion.ai/). The [`StableDiffusionPipeline`] is capable of generating photorealistic images given any text input. It's trained on 512x512 images from a subset of the LAION-5B dataset. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and can run on consumer GPUs. Latent diffusion is the research on top of which Stable Diffusion was built. It was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.
The abstract from the paper is:
diff --git a/docs/source/en/api/pipelines/stable_diffusion/upscale.md b/docs/source/en/api/pipelines/stable_diffusion/upscale.md
index b188c29bff6b..53a95d501e34 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/upscale.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/upscale.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Super-resolution
+
+
+
+
The Stable Diffusion upscaler diffusion model was created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), and [LAION](https://laion.ai/). It is used to enhance the resolution of input images by a factor of 4.
diff --git a/docs/source/en/api/pipelines/stable_unclip.md b/docs/source/en/api/pipelines/stable_unclip.md
index 3067ba91f752..9c281b28ab4d 100644
--- a/docs/source/en/api/pipelines/stable_unclip.md
+++ b/docs/source/en/api/pipelines/stable_unclip.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Stable unCLIP
+
+
+
+
Stable unCLIP checkpoints are finetuned from [Stable Diffusion 2.1](./stable_diffusion/stable_diffusion_2) checkpoints to condition on CLIP image embeddings.
Stable unCLIP still conditions on text embeddings. Given the two separate conditionings, stable unCLIP can be used
for text guided image variation. When combined with an unCLIP prior, it can also be used for full text to image generation.
@@ -97,7 +101,7 @@ image
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/text_to_video.md b/docs/source/en/api/pipelines/text_to_video.md
index 7522264e0b58..5eb1dd1a9dbd 100644
--- a/docs/source/en/api/pipelines/text_to_video.md
+++ b/docs/source/en/api/pipelines/text_to_video.md
@@ -18,6 +18,10 @@ specific language governing permissions and limitations under the License.
# Text-to-video
+
+
+
+
[ModelScope Text-to-Video Technical Report](https://arxiv.org/abs/2308.06571) is by Jiuniu Wang, Hangjie Yuan, Dayou Chen, Yingya Zhang, Xiang Wang, Shiwei Zhang.
The abstract from the paper is:
@@ -175,7 +179,7 @@ Check out the [Text or image-to-video](text-img2vid) guide for more details abou
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/text_to_video_zero.md b/docs/source/en/api/pipelines/text_to_video_zero.md
index c6bf30fed7af..44d9a6670af4 100644
--- a/docs/source/en/api/pipelines/text_to_video_zero.md
+++ b/docs/source/en/api/pipelines/text_to_video_zero.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Text2Video-Zero
+
+
+
+
[Text2Video-Zero: Text-to-Image Diffusion Models are Zero-Shot Video Generators](https://huggingface.co/papers/2303.13439) is by Levon Khachatryan, Andranik Movsisyan, Vahram Tadevosyan, Roberto Henschel, [Zhangyang Wang](https://www.ece.utexas.edu/people/faculty/atlas-wang), Shant Navasardyan, [Humphrey Shi](https://www.humphreyshi.com).
Text2Video-Zero enables zero-shot video generation using either:
@@ -284,7 +288,7 @@ You can filter out some available DreamBooth-trained models with [this link](htt
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/unclip.md b/docs/source/en/api/pipelines/unclip.md
index f379ffd63f53..943cebdb28a2 100644
--- a/docs/source/en/api/pipelines/unclip.md
+++ b/docs/source/en/api/pipelines/unclip.md
@@ -19,7 +19,7 @@ You can find lucidrains' DALL-E 2 recreation at [lucidrains/DALLE2-pytorch](http
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/unidiffuser.md b/docs/source/en/api/pipelines/unidiffuser.md
index 553a6d300152..802aefea6be5 100644
--- a/docs/source/en/api/pipelines/unidiffuser.md
+++ b/docs/source/en/api/pipelines/unidiffuser.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# UniDiffuser
+
+
+
+
The UniDiffuser model was proposed in [One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale](https://huggingface.co/papers/2303.06555) by Fan Bao, Shen Nie, Kaiwen Xue, Chongxuan Li, Shi Pu, Yaole Wang, Gang Yue, Yue Cao, Hang Su, Jun Zhu.
The abstract from the paper is:
@@ -192,7 +196,7 @@ print(final_prompt)
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/value_guided_sampling.md b/docs/source/en/api/pipelines/value_guided_sampling.md
index d21dbf04d7ee..5aaee9090cef 100644
--- a/docs/source/en/api/pipelines/value_guided_sampling.md
+++ b/docs/source/en/api/pipelines/value_guided_sampling.md
@@ -30,7 +30,7 @@ The script to run the model is available [here](https://github.com/huggingface/d
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md
new file mode 100644
index 000000000000..cb856fe0acfc
--- /dev/null
+++ b/docs/source/en/api/pipelines/wan.md
@@ -0,0 +1,465 @@
+
+
+# Wan
+
+
+
+
+
+[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
+
+
+
+## Generating Videos with Wan 2.1
+
+We will first need to install some addtional dependencies.
+
+```shell
+pip install -u ftfy imageio-ffmpeg imageio
+```
+
+### Text to Video Generation
+
+The following example requires 11GB VRAM to run and uses the smaller `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` model. You can switch it out
+for the larger `Wan2.1-I2V-14B-720P-Diffusers` or `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` if you have at least 35GB VRAM available.
+
+```python
+from diffusers import WanPipeline
+from diffusers.utils import export_to_video
+
+# Available models: Wan-AI/Wan2.1-I2V-14B-720P-Diffusers or Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
+model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+
+pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+pipe.enable_model_cpu_offload()
+
+prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+num_frames = 33
+
+frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
+export_to_video(frames, "wan-t2v.mp4", fps=16)
+```
+
+
+You can improve the quality of the generated video by running the decoding step in full precision.
+
+
+```python
+from diffusers import WanPipeline, AutoencoderKLWan
+from diffusers.utils import export_to_video
+
+model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+
+# replace this with pipe.to("cuda") if you have sufficient VRAM
+pipe.enable_model_cpu_offload()
+
+prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+num_frames = 33
+
+frames = pipe(prompt=prompt, num_frames=num_frames).frames[0]
+export_to_video(frames, "wan-t2v.mp4", fps=16)
+```
+
+### Image to Video Generation
+
+The Image to Video pipeline requires loading the `AutoencoderKLWan` and the `CLIPVisionModel` components in full precision. The following example will need at least
+35GB of VRAM to run.
+
+```python
+import torch
+import numpy as np
+from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+
+# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
+model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(
+ model_id, subfolder="image_encoder", torch_dtype=torch.float32
+)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanImageToVideoPipeline.from_pretrained(
+ model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+)
+
+# replace this with pipe.to("cuda") if you have sufficient VRAM
+pipe.enable_model_cpu_offload()
+
+image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+)
+
+max_area = 480 * 832
+aspect_ratio = image.height / image.width
+mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+image = image.resize((width, height))
+
+prompt = (
+ "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
+ "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+)
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+num_frames = 33
+
+output = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ guidance_scale=5.0,
+).frames[0]
+export_to_video(output, "wan-i2v.mp4", fps=16)
+```
+
+### Video to Video Generation
+
+```python
+import torch
+from diffusers.utils import load_video, export_to_video
+from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline, UniPCMultistepScheduler
+
+# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(
+ model_id, subfolder="vae", torch_dtype=torch.float32
+)
+pipe = WanVideoToVideoPipeline.from_pretrained(
+ model_id, vae=vae, torch_dtype=torch.bfloat16
+)
+flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
+pipe.scheduler = UniPCMultistepScheduler.from_config(
+ pipe.scheduler.config, flow_shift=flow_shift
+)
+# change to pipe.to("cuda") if you have sufficient VRAM
+pipe.enable_model_cpu_offload()
+
+prompt = "A robot standing on a mountain top. The sun is setting in the background"
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+video = load_video(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
+)
+output = pipe(
+ video=video,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=480,
+ width=512,
+ guidance_scale=7.0,
+ strength=0.7,
+).frames[0]
+
+export_to_video(output, "wan-v2v.mp4", fps=16)
+```
+
+## Memory Optimizations for Wan 2.1
+
+Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
+
+We'll use `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` model in these examples to demonstrate the memory savings, but the techniques are applicable to all model checkpoints.
+
+### Group Offloading the Transformer and UMT5 Text Encoder
+
+Find more information about group offloading [here](../optimization/memory.md)
+
+#### Block Level Group Offloading
+
+We can reduce our VRAM requirements by applying group offloading to the larger model components of the pipeline; the `WanTransformer3DModel` and `UMT5EncoderModel`. Group offloading will break up the individual modules of a model and offload/onload them onto your GPU as needed during inference. In this example, we'll apply `block_level` offloading, which will group the modules in a model into blocks of size `num_blocks_per_group` and offload/onload them to GPU. Moving to between CPU and GPU does add latency to the inference process. You can trade off between latency and memory savings by increasing or decreasing the `num_blocks_per_group`.
+
+The following example will now only require 14GB of VRAM to run, but will take approximately 30 minutes to generate a video.
+
+```python
+import torch
+import numpy as np
+from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
+from diffusers.hooks.group_offloading import apply_group_offloading
+from diffusers.utils import export_to_video, load_image
+from transformers import UMT5EncoderModel, CLIPVisionModel
+
+# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
+model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(
+ model_id, subfolder="image_encoder", torch_dtype=torch.float32
+)
+
+text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+
+apply_group_offloading(text_encoder,
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="block_level",
+ num_blocks_per_group=4
+)
+
+transformer.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="block_level",
+ num_blocks_per_group=4,
+)
+pipe = WanImageToVideoPipeline.from_pretrained(
+ model_id,
+ vae=vae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ image_encoder=image_encoder,
+ torch_dtype=torch.bfloat16
+)
+# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
+pipe.to("cuda")
+
+image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+)
+
+max_area = 720 * 832
+aspect_ratio = image.height / image.width
+mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+image = image.resize((width, height))
+
+prompt = (
+ "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
+ "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+)
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+num_frames = 33
+
+output = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ guidance_scale=5.0,
+).frames[0]
+
+export_to_video(output, "wan-i2v.mp4", fps=16)
+```
+
+#### Block Level Group Offloading with CUDA Streams
+
+We can speed up group offloading inference, by enabling the use of [CUDA streams](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html). However, using CUDA streams requires moving the model parameters into pinned memory. This allocation is handled by Pytorch under the hood, and can result in a significant spike in CPU RAM usage. Please consider this option if your CPU RAM is atleast 2X the size of the model you are group offloading.
+
+In the following example we will use CUDA streams when group offloading the `WanTransformer3DModel`. When testing on an A100, this example will require 14GB of VRAM, 52GB of CPU RAM, but will generate a video in approximately 9 minutes.
+
+```python
+import torch
+import numpy as np
+from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
+from diffusers.hooks.group_offloading import apply_group_offloading
+from diffusers.utils import export_to_video, load_image
+from transformers import UMT5EncoderModel, CLIPVisionModel
+
+# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
+model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(
+ model_id, subfolder="image_encoder", torch_dtype=torch.float32
+)
+
+text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+
+apply_group_offloading(text_encoder,
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="block_level",
+ num_blocks_per_group=4
+)
+
+transformer.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ use_stream=True
+)
+pipe = WanImageToVideoPipeline.from_pretrained(
+ model_id,
+ vae=vae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ image_encoder=image_encoder,
+ torch_dtype=torch.bfloat16
+)
+# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
+pipe.to("cuda")
+
+image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+)
+
+max_area = 720 * 832
+aspect_ratio = image.height / image.width
+mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+image = image.resize((width, height))
+
+prompt = (
+ "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
+ "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+)
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+num_frames = 33
+
+output = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ guidance_scale=5.0,
+).frames[0]
+
+export_to_video(output, "wan-i2v.mp4", fps=16)
+```
+
+### Applying Layerwise Casting to the Transformer
+
+Find more information about layerwise casting [here](../optimization/memory.md)
+
+In this example, we will model offloading with layerwise casting. Layerwise casting will downcast each layer's weights to `torch.float8_e4m3fn`, temporarily upcast to `torch.bfloat16` during the forward pass of the layer, then revert to `torch.float8_e4m3fn` afterward. This approach reduces memory requirements by approximately 50% while introducing a minor quality reduction in the generated video due to the precision trade-off.
+
+This example will require 20GB of VRAM.
+
+```python
+import torch
+import numpy as np
+from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
+from diffusers.hooks.group_offloading import apply_group_offloading
+from diffusers.utils import export_to_video, load_image
+from transformers import UMT5EncoderModel, CLIPVisionModel
+
+model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(
+ model_id, subfolder="image_encoder", torch_dtype=torch.float32
+)
+text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+
+transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+
+pipe = WanImageToVideoPipeline.from_pretrained(
+ model_id,
+ vae=vae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ image_encoder=image_encoder,
+ torch_dtype=torch.bfloat16
+)
+pipe.enable_model_cpu_offload()
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
+
+max_area = 720 * 832
+aspect_ratio = image.height / image.width
+mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+image = image.resize((width, height))
+prompt = (
+ "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
+ "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+)
+negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+num_frames = 33
+
+output = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+).frames[0]
+export_to_video(output, "wan-i2v.mp4", fps=16)
+```
+
+## Using a Custom Scheduler
+
+Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
+
+```python
+from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler, WanPipeline
+
+scheduler_a = FlowMatchEulerDiscreteScheduler(shift=5.0)
+scheduler_b = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=4.0)
+
+pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler=)
+
+# or,
+pipe.scheduler =
+```
+
+## Using Single File Loading with Wan 2.1
+
+The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
+method.
+
+```python
+import torch
+from diffusers import WanPipeline, WanTransformer3DModel
+
+ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
+transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
+
+pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
+```
+
+## Recommendations for Inference
+- Keep `AutencoderKLWan` in `torch.float32` for better decoding quality.
+- `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`
+- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
+
+## WanPipeline
+
+[[autodoc]] WanPipeline
+ - all
+ - __call__
+
+## WanImageToVideoPipeline
+
+[[autodoc]] WanImageToVideoPipeline
+ - all
+ - __call__
+
+## WanPipelineOutput
+
+[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md
index 4d90ad46dc64..da6ef2cffc28 100644
--- a/docs/source/en/api/pipelines/wuerstchen.md
+++ b/docs/source/en/api/pipelines/wuerstchen.md
@@ -12,6 +12,10 @@ specific language governing permissions and limitations under the License.
# Würstchen
+
+
+
+
[Wuerstchen: An Efficient Architecture for Large-Scale Text-to-Image Diffusion Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, Mats L. Richter and Christopher Pal and Marc Aubreville.
diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md
index 2fbde9e707ea..2c728cff3c07 100644
--- a/docs/source/en/api/quantization.md
+++ b/docs/source/en/api/quantization.md
@@ -28,6 +28,18 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
[[autodoc]] BitsAndBytesConfig
+## GGUFQuantizationConfig
+
+[[autodoc]] GGUFQuantizationConfig
+
+## QuantoConfig
+
+[[autodoc]] QuantoConfig
+
+## TorchAoConfig
+
+[[autodoc]] TorchAoConfig
+
## DiffusersQuantizer
[[autodoc]] quantizers.base.DiffusersQuantizer
diff --git a/docs/source/en/api/schedulers/ddim_cogvideox.md b/docs/source/en/api/schedulers/ddim_cogvideox.md
new file mode 100644
index 000000000000..d3ff380306c7
--- /dev/null
+++ b/docs/source/en/api/schedulers/ddim_cogvideox.md
@@ -0,0 +1,19 @@
+
+
+# CogVideoXDDIMScheduler
+
+`CogVideoXDDIMScheduler` is based on [Denoising Diffusion Implicit Models](https://huggingface.co/papers/2010.02502), specifically for CogVideoX models.
+
+## CogVideoXDDIMScheduler
+
+[[autodoc]] CogVideoXDDIMScheduler
diff --git a/docs/source/en/api/schedulers/multistep_dpm_solver_cogvideox.md b/docs/source/en/api/schedulers/multistep_dpm_solver_cogvideox.md
new file mode 100644
index 000000000000..bce09a15f543
--- /dev/null
+++ b/docs/source/en/api/schedulers/multistep_dpm_solver_cogvideox.md
@@ -0,0 +1,19 @@
+
+
+# CogVideoXDPMScheduler
+
+`CogVideoXDPMScheduler` is based on [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095), specifically for CogVideoX models.
+
+## CogVideoXDPMScheduler
+
+[[autodoc]] CogVideoXDPMScheduler
diff --git a/docs/source/en/api/utilities.md b/docs/source/en/api/utilities.md
index d4f4d7d7964f..b653cdafbb28 100644
--- a/docs/source/en/api/utilities.md
+++ b/docs/source/en/api/utilities.md
@@ -41,3 +41,11 @@ Utility and helper functions for working with 🤗 Diffusers.
## randn_tensor
[[autodoc]] utils.torch_utils.randn_tensor
+
+## apply_layerwise_casting
+
+[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting
+
+## apply_group_offloading
+
+[[autodoc]] hooks.group_offloading.apply_group_offloading
diff --git a/docs/source/en/community_projects.md b/docs/source/en/community_projects.md
index 4ab1829871c8..dcca0a504d86 100644
--- a/docs/source/en/community_projects.md
+++ b/docs/source/en/community_projects.md
@@ -79,4 +79,8 @@ Happy exploring, and thank you for being part of the Diffusers community!
Stable Diffusion Server
A server configured for Inpainting/Generation/img2img with one stable diffusion model
+
+ Model Search
+ Search models on Civitai and Hugging Face
+
diff --git a/docs/source/en/conceptual/evaluation.md b/docs/source/en/conceptual/evaluation.md
index 8dfbc8f2ac80..131b888e7a72 100644
--- a/docs/source/en/conceptual/evaluation.md
+++ b/docs/source/en/conceptual/evaluation.md
@@ -16,6 +16,11 @@ specific language governing permissions and limitations under the License.
+> [!TIP]
+> This document has now grown outdated given the emergence of existing evaluation frameworks for diffusion models for image generation. Please check
+> out works like [HEIM](https://crfm.stanford.edu/helm/heim/latest/), [T2I-Compbench](https://arxiv.org/abs/2307.06350),
+> [GenEval](https://arxiv.org/abs/2310.11513).
+
Evaluation of generative models like [Stable Diffusion](https://huggingface.co/docs/diffusers/stable_diffusion) is subjective in nature. But as practitioners and researchers, we often have to make careful choices amongst many different possibilities. So, when working with different generative models (like GANs, Diffusion, etc.), how do we choose one over the other?
Qualitative evaluation of such models can be error-prone and might incorrectly influence a decision.
@@ -181,7 +186,7 @@ Then we load the [v1-5 checkpoint](https://huggingface.co/stable-diffusion-v1-5/
```python
model_ckpt_1_5 = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-sd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=weight_dtype).to(device)
+sd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=torch.float16).to("cuda")
images_1_5 = sd_pipeline_1_5(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images
```
@@ -280,7 +285,7 @@ from diffusers import StableDiffusionInstructPix2PixPipeline
instruct_pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
"timbrooks/instruct-pix2pix", torch_dtype=torch.float16
-).to(device)
+).to("cuda")
```
Now, we perform the edits:
@@ -326,9 +331,9 @@ from transformers import (
clip_id = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(clip_id)
-text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to(device)
+text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to("cuda")
image_processor = CLIPImageProcessor.from_pretrained(clip_id)
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(device)
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to("cuda")
```
Notice that we are using a particular CLIP checkpoint, i.e., `openai/clip-vit-large-patch14`. This is because the Stable Diffusion pre-training was performed with this CLIP variant. For more details, refer to the [documentation](https://huggingface.co/docs/transformers/model_doc/clip).
@@ -350,7 +355,7 @@ class DirectionalSimilarity(nn.Module):
def preprocess_image(self, image):
image = self.image_processor(image, return_tensors="pt")["pixel_values"]
- return {"pixel_values": image.to(device)}
+ return {"pixel_values": image.to("cuda")}
def tokenize_text(self, text):
inputs = self.tokenizer(
@@ -360,7 +365,7 @@ class DirectionalSimilarity(nn.Module):
truncation=True,
return_tensors="pt",
)
- return {"input_ids": inputs.input_ids.to(device)}
+ return {"input_ids": inputs.input_ids.to("cuda")}
def encode_image(self, image):
preprocessed_image = self.preprocess_image(image)
@@ -459,6 +464,7 @@ with ZipFile(local_filepath, "r") as zipper:
```python
from PIL import Image
import os
+import numpy as np
dataset_path = "sample-imagenet-images"
image_paths = sorted([os.path.join(dataset_path, x) for x in os.listdir(dataset_path)])
@@ -477,6 +483,7 @@ Now that the images are loaded, let's apply some lightweight pre-processing on t
```python
from torchvision.transforms import functional as F
+import torch
def preprocess_image(image):
@@ -498,6 +505,10 @@ dit_pipeline = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=
dit_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(dit_pipeline.scheduler.config)
dit_pipeline = dit_pipeline.to("cuda")
+seed = 0
+generator = torch.manual_seed(seed)
+
+
words = [
"cassette player",
"chainsaw",
diff --git a/docs/source/en/hybrid_inference/api_reference.md b/docs/source/en/hybrid_inference/api_reference.md
new file mode 100644
index 000000000000..865aaba5ebb6
--- /dev/null
+++ b/docs/source/en/hybrid_inference/api_reference.md
@@ -0,0 +1,9 @@
+# Hybrid Inference API Reference
+
+## Remote Decode
+
+[[autodoc]] utils.remote_utils.remote_decode
+
+## Remote Encode
+
+[[autodoc]] utils.remote_utils.remote_encode
diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md
new file mode 100644
index 000000000000..b44393c77cbd
--- /dev/null
+++ b/docs/source/en/hybrid_inference/overview.md
@@ -0,0 +1,60 @@
+
+
+# Hybrid Inference
+
+**Empowering local AI builders with Hybrid Inference**
+
+
+> [!TIP]
+> Hybrid Inference is an [experimental feature](https://huggingface.co/blog/remote_vae).
+> Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
+
+
+
+## Why use Hybrid Inference?
+
+Hybrid Inference offers a fast and simple way to offload local generation requirements.
+
+- 🚀 **Reduced Requirements:** Access powerful models without expensive hardware.
+- 💎 **Without Compromise:** Achieve the highest quality without sacrificing performance.
+- 💰 **Cost Effective:** It's free! 🤑
+- 🎯 **Diverse Use Cases:** Fully compatible with Diffusers 🧨 and the wider community.
+- 🔧 **Developer-Friendly:** Simple requests, fast responses.
+
+---
+
+## Available Models
+
+* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
+* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
+* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
+
+---
+
+## Integrations
+
+* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
+* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
+
+## Changelog
+
+- March 10 2025: Added VAE encode
+- March 2 2025: Initial release with VAE decoding
+
+## Contents
+
+The documentation is organized into three sections:
+
+* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
+* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
+* **API Reference** Dive into task-specific settings and parameters.
diff --git a/docs/source/en/hybrid_inference/vae_decode.md b/docs/source/en/hybrid_inference/vae_decode.md
new file mode 100644
index 000000000000..1457090550c7
--- /dev/null
+++ b/docs/source/en/hybrid_inference/vae_decode.md
@@ -0,0 +1,345 @@
+# Getting Started: VAE Decode with Hybrid Inference
+
+VAE decode is an essential component of diffusion models - turning latent representations into images or videos.
+
+## Memory
+
+These tables demonstrate the VRAM requirements for VAE decode with SD v1 and SD XL on different GPUs.
+
+For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled decoding has to be used which increases time taken and impacts quality.
+
+SD v1.5
+
+| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
+| --- | --- | --- | --- | --- | --- |
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
+| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
+| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
+| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
+| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
+
+
+
+SDXL
+
+| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
+| --- | --- | --- | --- | --- | --- |
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
+| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
+| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
+| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
+| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
+
+
+
+## Available VAEs
+
+| | **Endpoint** | **Model** |
+|:-:|:-----------:|:--------:|
+| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
+| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
+| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
+| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
+
+
+> [!TIP]
+> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
+
+
+## Code
+
+> [!TIP]
+> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
+
+
+A helper method simplifies interacting with Hybrid Inference.
+
+```python
+from diffusers.utils.remote_utils import remote_decode
+```
+
+### Basic example
+
+Here, we show how to use the remote VAE on random tensors.
+
+Code
+
+```python
+image = remote_decode(
+ endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
+ scaling_factor=0.18215,
+)
+```
+
+
+
+
+
+
+
+Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`.
+
+Code
+
+```python
+image = remote_decode(
+ endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
+ height=1024,
+ width=1024,
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+)
+```
+
+
+
+
+
+
+
+Finally, an example for HunyuanVideo.
+
+Code
+
+```python
+video = remote_decode(
+ endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16),
+ output_type="mp4",
+)
+with open("video.mp4", "wb") as f:
+ f.write(video)
+```
+
+
+
+
+
+
+
+
+
+
+### Generation
+
+But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5.
+
+Code
+
+```python
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ vae=None,
+).to("cuda")
+
+prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
+
+latent = pipe(
+ prompt=prompt,
+ output_type="latent",
+).images
+image = remote_decode(
+ endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=latent,
+ scaling_factor=0.18215,
+)
+image.save("test.jpg")
+```
+
+
+
+
+
+
+
+Here’s another example with Flux.
+
+Code
+
+```python
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ torch_dtype=torch.bfloat16,
+ vae=None,
+).to("cuda")
+
+prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
+
+latent = pipe(
+ prompt=prompt,
+ guidance_scale=0.0,
+ num_inference_steps=4,
+ output_type="latent",
+).images
+image = remote_decode(
+ endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=latent,
+ height=1024,
+ width=1024,
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+)
+image.save("test.jpg")
+```
+
+
+
+
+
+
+
+Here’s an example with HunyuanVideo.
+
+Code
+
+```python
+from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+
+model_id = "hunyuanvideo-community/HunyuanVideo"
+transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+)
+pipe = HunyuanVideoPipeline.from_pretrained(
+ model_id, transformer=transformer, vae=None, torch_dtype=torch.float16
+).to("cuda")
+
+latent = pipe(
+ prompt="A cat walks on the grass, realistic",
+ height=320,
+ width=512,
+ num_frames=61,
+ num_inference_steps=30,
+ output_type="latent",
+).frames
+
+video = remote_decode(
+ endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=latent,
+ output_type="mp4",
+)
+
+if isinstance(video, bytes):
+ with open("video.mp4", "wb") as f:
+ f.write(video)
+```
+
+
+
+
+
+
+
+
+
+
+### Queueing
+
+One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency.
+
+
+Code
+
+```python
+import queue
+import threading
+from IPython.display import display
+from diffusers import StableDiffusionPipeline
+
+def decode_worker(q: queue.Queue):
+ while True:
+ item = q.get()
+ if item is None:
+ break
+ image = remote_decode(
+ endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=item,
+ scaling_factor=0.18215,
+ )
+ display(image)
+ q.task_done()
+
+q = queue.Queue()
+thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
+thread.start()
+
+def decode(latent: torch.Tensor):
+ q.put(latent)
+
+prompts = [
+ "Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious",
+ "Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore",
+ "Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.",
+ "Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP",
+ "A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting",
+ "Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,",
+]
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "Lykon/dreamshaper-8",
+ torch_dtype=torch.float16,
+ vae=None,
+).to("cuda")
+
+pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
+pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+_ = pipe(
+ prompt=prompts[0],
+ output_type="latent",
+)
+
+for prompt in prompts:
+ latent = pipe(
+ prompt=prompt,
+ output_type="latent",
+ ).images
+ decode(latent)
+
+q.put(None)
+thread.join()
+```
+
+
+
+
+
+
+
+
+
+
+## Integrations
+
+* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
+* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
diff --git a/docs/source/en/hybrid_inference/vae_encode.md b/docs/source/en/hybrid_inference/vae_encode.md
new file mode 100644
index 000000000000..dd285fa25c03
--- /dev/null
+++ b/docs/source/en/hybrid_inference/vae_encode.md
@@ -0,0 +1,183 @@
+# Getting Started: VAE Encode with Hybrid Inference
+
+VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations.
+
+## Memory
+
+These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
+
+For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality.
+
+SD v1.5
+
+| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
+|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
+| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
+| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
+| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
+| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
+| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
+| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
+| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
+| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
+| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
+| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
+| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
+| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
+
+
+
+
+SDXL
+
+| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
+|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
+| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
+| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
+| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
+| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
+| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
+| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
+| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
+| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
+| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
+| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
+| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
+| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
+
+
+
+## Available VAEs
+
+| | **Endpoint** | **Model** |
+|:-:|:-----------:|:--------:|
+| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
+| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
+| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
+
+
+> [!TIP]
+> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
+
+
+## Code
+
+> [!TIP]
+> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
+
+
+A helper method simplifies interacting with Hybrid Inference.
+
+```python
+from diffusers.utils.remote_utils import remote_encode
+```
+
+### Basic example
+
+Let's encode an image, then decode it to demonstrate.
+
+
+
+
+
+Code
+
+```python
+from diffusers.utils import load_image
+from diffusers.utils.remote_utils import remote_decode
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
+
+latent = remote_encode(
+ endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+)
+
+decoded = remote_decode(
+ endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=latent,
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+)
+```
+
+
+
+
+
+
+
+
+### Generation
+
+Now let's look at a generation example, we'll encode the image, generate then remotely decode too!
+
+Code
+
+```python
+import torch
+from diffusers import StableDiffusionImg2ImgPipeline
+from diffusers.utils import load_image
+from diffusers.utils.remote_utils import remote_decode, remote_encode
+
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ vae=None,
+).to("cuda")
+
+init_image = load_image(
+ "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+)
+init_image = init_image.resize((768, 512))
+
+init_latent = remote_encode(
+ endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
+ image=init_image,
+ scaling_factor=0.18215,
+)
+
+prompt = "A fantasy landscape, trending on artstation"
+latent = pipe(
+ prompt=prompt,
+ image=init_latent,
+ strength=0.75,
+ output_type="latent",
+).images
+
+image = remote_decode(
+ endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=latent,
+ scaling_factor=0.18215,
+)
+image.save("fantasy_landscape.jpg")
+```
+
+
+
+
+
+
+
+## Integrations
+
+* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
+* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
diff --git a/docs/source/en/installation.md b/docs/source/en/installation.md
index 74cfa70d70fc..570fac096862 100644
--- a/docs/source/en/installation.md
+++ b/docs/source/en/installation.md
@@ -23,32 +23,60 @@ You should install 🤗 Diffusers in a [virtual environment](https://docs.python
If you're unfamiliar with Python virtual environments, take a look at this [guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
A virtual environment makes it easier to manage different projects and avoid compatibility issues between dependencies.
-Start by creating a virtual environment in your project directory:
+Create a virtual environment with Python or [uv](https://docs.astral.sh/uv/) (refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), a fast Rust-based Python package and project manager.
+
+
+
```bash
-python -m venv .env
+uv venv my-env
+source my-env/bin/activate
```
-Activate the virtual environment:
+
+
```bash
-source .env/bin/activate
+python -m venv my-env
+source my-env/bin/activate
```
-You should also install 🤗 Transformers because 🤗 Diffusers relies on its models:
+
+
+
+You should also install 🤗 Transformers because 🤗 Diffusers relies on its models.
-Note - PyTorch only supports Python 3.8 - 3.11 on Windows.
+
+PyTorch only supports Python 3.8 - 3.11 on Windows. Install Diffusers with uv.
+
+```bash
+uv install diffusers["torch"] transformers
+```
+
+You can also install Diffusers with pip.
+
```bash
pip install diffusers["torch"] transformers
```
+
+
+Install Diffusers with uv.
+
+```bash
+uv pip install diffusers["flax"] transformers
+```
+
+You can also install Diffusers with pip.
+
```bash
pip install diffusers["flax"] transformers
```
+
@@ -133,10 +161,10 @@ Your Python environment will find the `main` version of 🤗 Diffusers on the ne
Model weights and files are downloaded from the Hub to a cache which is usually your home directory. You can change the cache location by specifying the `HF_HOME` or `HUGGINFACE_HUB_CACHE` environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
-Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `True` and 🤗 Diffusers will only load previously downloaded files in the cache.
+Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `1` and 🤗 Diffusers will only load previously downloaded files in the cache.
```shell
-export HF_HUB_OFFLINE=True
+export HF_HUB_OFFLINE=1
```
For more details about managing and cleaning the cache, take a look at the [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
@@ -151,14 +179,16 @@ Telemetry is only sent when loading models and pipelines from the Hub,
and it is not collected if you're loading local files.
We understand that not everyone wants to share additional information,and we respect your privacy.
-You can disable telemetry collection by setting the `DISABLE_TELEMETRY` environment variable from your terminal:
+You can disable telemetry collection by setting the `HF_HUB_DISABLE_TELEMETRY` environment variable from your terminal:
On Linux/MacOS:
+
```bash
-export DISABLE_TELEMETRY=YES
+export HF_HUB_DISABLE_TELEMETRY=1
```
On Windows:
+
```bash
-set DISABLE_TELEMETRY=YES
+set HF_HUB_DISABLE_TELEMETRY=1
```
diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md
index a2150f9aa0b7..fd72957471c0 100644
--- a/docs/source/en/optimization/memory.md
+++ b/docs/source/en/optimization/memory.md
@@ -158,6 +158,103 @@ In order to properly offload models after they're called, it is required to run
+## Group offloading
+
+Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced.
+
+To enable group offloading, call the [`~ModelMixin.enable_group_offload`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]:
+
+```python
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.hooks import apply_group_offloading
+from diffusers.utils import export_to_video
+
+# Load the pipeline
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+
+# We can utilize the enable_group_offload method for Diffusers model implementations
+pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
+
+# For any other model implementations, the apply_group_offloading function can be used
+apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
+apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
+
+prompt = (
+ "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+ "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+ "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+ "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+ "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+ "atmosphere of this unique musical performance."
+)
+video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+# This utilized about 14.79 GB. It can be further reduced by using tiling and using leaf_level offloading throughout the pipeline.
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+export_to_video(video, "output.mp4", fps=8)
+```
+
+Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams.
+
+
+
+- Group offloading may not work with all models out-of-the-box. If the forward implementations of the model contain weight-dependent device-casting of inputs, it may clash with the offloading mechanism's handling of device-casting.
+- The `offload_type` parameter can be set to either `block_level` or `leaf_level`. `block_level` offloads groups of `torch::nn::ModuleList` or `torch::nn:Sequential` modules based on a configurable attribute `num_blocks_per_group`. For example, if you set `num_blocks_per_group=2` on a standard transformer model containing 40 layers, it will onload/offload 2 layers at a time for a total of 20 onload/offloads. This drastically reduces the VRAM requirements. `leaf_level` offloads individual layers at the lowest level, which is equivalent to sequential offloading. However, unlike sequential offloading, group offloading can be made much faster when using streams, with minimal compromise to end-to-end generation time.
+- The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html)
+- If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
+- The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading.
+
+For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].
+
+
+
+## FP8 layerwise weight-casting
+
+PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
+
+Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half.
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.utils import export_to_video
+
+model_id = "THUDM/CogVideoX-5b"
+
+# Load the model in bfloat16 and enable layerwise casting
+transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+
+# Load the pipeline
+pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = (
+ "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+ "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+ "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+ "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+ "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+ "atmosphere of this unique musical performance."
+)
+video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+export_to_video(video, "output.mp4", fps=8)
+```
+
+In the above example, layerwise casting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default.
+
+However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
+
+
+
+- Layerwise casting may not work with all models out-of-the-box. Sometimes, the forward implementations of the model might contain internal typecasting of weight values. Such implementations are not supported due to the currently simplistic implementation of layerwise casting, which assumes that the forward pass is independent of the weight precision and that the input dtypes are always in `compute_dtype`. An example of an incompatible implementation can be found [here](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299).
+- Layerwise casting may fail on custom modeling implementations that make use of [PEFT](https://github.com/huggingface/peft) layers. Some minimal checks to handle this case is implemented but is not extensively tested or guaranteed to work in all cases.
+- It can be also be applied partially to specific layers of a model. Partially applying layerwise casting can either be done manually by calling the `apply_layerwise_casting` function on specific internal modules, or by specifying the `skip_modules_pattern` and `skip_modules_classes` parameters for a root module. These parameters are particularly useful for layers such as normalization and modulation.
+
+
+
## Channels-last memory format
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.
diff --git a/docs/source/en/optimization/neuron.md b/docs/source/en/optimization/neuron.md
new file mode 100644
index 000000000000..b10050e64d7f
--- /dev/null
+++ b/docs/source/en/optimization/neuron.md
@@ -0,0 +1,61 @@
+
+
+# AWS Neuron
+
+Diffusers functionalities are available on [AWS Inf2 instances](https://aws.amazon.com/ec2/instance-types/inf2/), which are EC2 instances powered by [Neuron machine learning accelerators](https://aws.amazon.com/machine-learning/inferentia/). These instances aim to provide better compute performance (higher throughput, lower latency) with good cost-efficiency, making them good candidates for AWS users to deploy diffusion models to production.
+
+[Optimum Neuron](https://huggingface.co/docs/optimum-neuron/en/index) is the interface between Hugging Face libraries and AWS Accelerators, including AWS [Trainium](https://aws.amazon.com/machine-learning/trainium/) and AWS [Inferentia](https://aws.amazon.com/machine-learning/inferentia/). It supports many of the features in Diffusers with similar APIs, so it is easier to learn if you're already familiar with Diffusers. Once you have created an AWS Inf2 instance, install Optimum Neuron.
+
+```bash
+python -m pip install --upgrade-strategy eager optimum[neuronx]
+```
+
+
+
+We provide pre-built [Hugging Face Neuron Deep Learning AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2) (DLAMI) and Optimum Neuron containers for Amazon SageMaker. It's recommended to correctly set up your environment.
+
+
+
+The example below demonstrates how to generate images with the Stable Diffusion XL model on an inf2.8xlarge instance (you can switch to cheaper inf2.xlarge instances once the model is compiled). To generate some images, use the [`~optimum.neuron.NeuronStableDiffusionXLPipeline`] class, which is similar to the [`StableDiffusionXLPipeline`] class in Diffusers.
+
+Unlike Diffusers, you need to compile models in the pipeline to the Neuron format, `.neuron`. Launch the following command to export the model to the `.neuron` format.
+
+```bash
+optimum-cli export neuron --model stabilityai/stable-diffusion-xl-base-1.0 \
+ --batch_size 1 \
+ --height 1024 `# height in pixels of generated image, eg. 768, 1024` \
+ --width 1024 `# width in pixels of generated image, eg. 768, 1024` \
+ --num_images_per_prompt 1 `# number of images to generate per prompt, defaults to 1` \
+ --auto_cast matmul `# cast only matrix multiplication operations` \
+ --auto_cast_type bf16 `# cast operations from FP32 to BF16` \
+ sd_neuron_xl/
+```
+
+Now generate some images with the pre-compiled SDXL model.
+
+```python
+>>> from optimum.neuron import NeuronStableDiffusionXLPipeline
+
+>>> stable_diffusion_xl = NeuronStableDiffusionXLPipeline.from_pretrained("sd_neuron_xl/")
+>>> prompt = "a pig with wings flying in floating US dollar banknotes in the air, skyscrapers behind, warm color palette, muted colors, detailed, 8k"
+>>> image = stable_diffusion_xl(prompt).images[0]
+```
+
+
+
+Feel free to check out more guides and examples on different use cases from the Optimum Neuron [documentation](https://huggingface.co/docs/optimum-neuron/en/inference_tutorials/stable_diffusion#generate-images-with-stable-diffusion-models-on-aws-inferentia)!
diff --git a/docs/source/en/optimization/para_attn.md b/docs/source/en/optimization/para_attn.md
new file mode 100644
index 000000000000..94b0d5ce3af4
--- /dev/null
+++ b/docs/source/en/optimization/para_attn.md
@@ -0,0 +1,497 @@
+# ParaAttention
+
+
+
+
+
+
+
+
+
+Large image and video generation models, such as [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo), can be an inference challenge for real-time applications and deployment because of their size.
+
+[ParaAttention](https://github.com/chengzeyi/ParaAttention) is a library that implements **context parallelism** and **first block cache**, and can be combined with other techniques (torch.compile, fp8 dynamic quantization), to accelerate inference.
+
+This guide will show you how to apply ParaAttention to FLUX.1-dev and HunyuanVideo on NVIDIA L20 GPUs.
+No optimizations are applied for our baseline benchmark, except for HunyuanVideo to avoid out-of-memory errors.
+
+Our baseline benchmark shows that FLUX.1-dev is able to generate a 1024x1024 resolution image in 28 steps in 26.36 seconds, and HunyuanVideo is able to generate 129 frames at 720p resolution in 30 steps in 3675.71 seconds.
+
+> [!TIP]
+> For even faster inference with context parallelism, try using NVIDIA A100 or H100 GPUs (if available) with NVLink support, especially when there is a large number of GPUs.
+
+## First Block Cache
+
+Caching the output of the transformers blocks in the model and reusing them in the next inference steps reduces the computation cost and makes inference faster.
+
+However, it is hard to decide when to reuse the cache to ensure quality generated images or videos. ParaAttention directly uses the **residual difference of the first transformer block output** to approximate the difference among model outputs. When the difference is small enough, the residual difference of previous inference steps is reused. In other words, the denoising step is skipped.
+
+This achieves a 2x speedup on FLUX.1-dev and HunyuanVideo inference with very good quality.
+
+
+
+ How AdaCache works, First Block Cache is a variant of it
+
+
+
+
+
+To apply first block cache on FLUX.1-dev, call `apply_cache_on_pipe` as shown below. 0.08 is the default residual difference value for FLUX models.
+
+```python
+import time
+import torch
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(pipe, residual_diff_threshold=0.08)
+
+# Enable memory savings
+# pipe.enable_model_cpu_offload()
+# pipe.enable_sequential_cpu_offload()
+
+begin = time.time()
+image = pipe(
+ "A cat holding a sign that says hello world",
+ num_inference_steps=28,
+).images[0]
+end = time.time()
+print(f"Time: {end - begin:.2f}s")
+
+print("Saving image to flux.png")
+image.save("flux.png")
+```
+
+| Optimizations | Original | FBCache rdt=0.06 | FBCache rdt=0.08 | FBCache rdt=0.10 | FBCache rdt=0.12 |
+| - | - | - | - | - | - |
+| Preview |  |  |  |  |  |
+| Wall Time (s) | 26.36 | 21.83 | 17.01 | 16.00 | 13.78 |
+
+First Block Cache reduced the inference speed to 17.01 seconds compared to the baseline, or 1.55x faster, while maintaining nearly zero quality loss.
+
+
+
+
+To apply First Block Cache on HunyuanVideo, `apply_cache_on_pipe` as shown below. 0.06 is the default residual difference value for HunyuanVideo models.
+
+```python
+import time
+import torch
+from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+from diffusers.utils import export_to_video
+
+model_id = "tencent/HunyuanVideo"
+transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16,
+ revision="refs/pr/18",
+)
+pipe = HunyuanVideoPipeline.from_pretrained(
+ model_id,
+ transformer=transformer,
+ torch_dtype=torch.float16,
+ revision="refs/pr/18",
+).to("cuda")
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(pipe, residual_diff_threshold=0.6)
+
+pipe.vae.enable_tiling()
+
+begin = time.time()
+output = pipe(
+ prompt="A cat walks on the grass, realistic",
+ height=720,
+ width=1280,
+ num_frames=129,
+ num_inference_steps=30,
+).frames[0]
+end = time.time()
+print(f"Time: {end - begin:.2f}s")
+
+print("Saving video to hunyuan_video.mp4")
+export_to_video(output, "hunyuan_video.mp4", fps=15)
+```
+
+
+
+ Your browser does not support the video tag.
+
+
+ HunyuanVideo without FBCache
+
+
+
+ Your browser does not support the video tag.
+
+
+ HunyuanVideo with FBCache
+
+First Block Cache reduced the inference speed to 2271.06 seconds compared to the baseline, or 1.62x faster, while maintaining nearly zero quality loss.
+
+
+
+
+## fp8 quantization
+
+fp8 with dynamic quantization further speeds up inference and reduces memory usage. Both the activations and weights must be quantized in order to use the 8-bit [NVIDIA Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/).
+
+Use `float8_weight_only` and `float8_dynamic_activation_float8_weight` to quantize the text encoder and transformer model.
+
+The default quantization method is per tensor quantization, but if your GPU supports row-wise quantization, you can also try it for better accuracy.
+
+Install [torchao](https://github.com/pytorch/ao/tree/main) with the command below.
+
+```bash
+pip3 install -U torch torchao
+```
+
+[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) with `mode="max-autotune-no-cudagraphs"` or `mode="max-autotune"` selects the best kernel for performance. Compilation can take a long time if it's the first time the model is called, but it is worth it once the model has been compiled.
+
+This example only quantizes the transformer model, but you can also quantize the text encoder to reduce memory usage even more.
+
+> [!TIP]
+> Dynamic quantization can significantly change the distribution of the model output, so you need to change the `residual_diff_threshold` to a larger value for it to take effect.
+
+
+
+
+```python
+import time
+import torch
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(
+ pipe,
+ residual_diff_threshold=0.12, # Use a larger value to make the cache take effect
+)
+
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
+
+quantize_(pipe.text_encoder, float8_weight_only())
+quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
+pipe.transformer = torch.compile(
+ pipe.transformer, mode="max-autotune-no-cudagraphs",
+)
+
+# Enable memory savings
+# pipe.enable_model_cpu_offload()
+# pipe.enable_sequential_cpu_offload()
+
+for i in range(2):
+ begin = time.time()
+ image = pipe(
+ "A cat holding a sign that says hello world",
+ num_inference_steps=28,
+ ).images[0]
+ end = time.time()
+ if i == 0:
+ print(f"Warm up time: {end - begin:.2f}s")
+ else:
+ print(f"Time: {end - begin:.2f}s")
+
+print("Saving image to flux.png")
+image.save("flux.png")
+```
+
+fp8 dynamic quantization and torch.compile reduced the inference speed to 7.56 seconds compared to the baseline, or 3.48x faster.
+
+
+
+
+```python
+import time
+import torch
+from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+from diffusers.utils import export_to_video
+
+model_id = "tencent/HunyuanVideo"
+transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16,
+ revision="refs/pr/18",
+)
+pipe = HunyuanVideoPipeline.from_pretrained(
+ model_id,
+ transformer=transformer,
+ torch_dtype=torch.float16,
+ revision="refs/pr/18",
+).to("cuda")
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(pipe)
+
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
+
+quantize_(pipe.text_encoder, float8_weight_only())
+quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
+pipe.transformer = torch.compile(
+ pipe.transformer, mode="max-autotune-no-cudagraphs",
+)
+
+# Enable memory savings
+pipe.vae.enable_tiling()
+# pipe.enable_model_cpu_offload()
+# pipe.enable_sequential_cpu_offload()
+
+for i in range(2):
+ begin = time.time()
+ output = pipe(
+ prompt="A cat walks on the grass, realistic",
+ height=720,
+ width=1280,
+ num_frames=129,
+ num_inference_steps=1 if i == 0 else 30,
+ ).frames[0]
+ end = time.time()
+ if i == 0:
+ print(f"Warm up time: {end - begin:.2f}s")
+ else:
+ print(f"Time: {end - begin:.2f}s")
+
+print("Saving video to hunyuan_video.mp4")
+export_to_video(output, "hunyuan_video.mp4", fps=15)
+```
+
+A NVIDIA L20 GPU only has 48GB memory and could face out-of-memory (OOM) errors after compilation and if `enable_model_cpu_offload` isn't called because HunyuanVideo has very large activation tensors when running with high resolution and large number of frames. For GPUs with less than 80GB of memory, you can try reducing the resolution and number of frames to avoid OOM errors.
+
+Large video generation models are usually bottlenecked by the attention computations rather than the fully connected layers. These models don't significantly benefit from quantization and torch.compile.
+
+
+
+
+## Context Parallelism
+
+Context Parallelism parallelizes inference and scales with multiple GPUs. The ParaAttention compositional design allows you to combine Context Parallelism with First Block Cache and dynamic quantization.
+
+> [!TIP]
+> Refer to the [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main) repository for detailed instructions and examples of how to scale inference with multiple GPUs.
+
+If the inference process needs to be persistent and serviceable, it is suggested to use [torch.multiprocessing](https://pytorch.org/docs/stable/multiprocessing.html) to write your own inference processor. This can eliminate the overhead of launching the process and loading and recompiling the model.
+
+
+
+
+The code sample below combines First Block Cache, fp8 dynamic quantization, torch.compile, and Context Parallelism for the fastest inference speed.
+
+```python
+import time
+import torch
+import torch.distributed as dist
+from diffusers import FluxPipeline
+
+dist.init_process_group()
+
+torch.cuda.set_device(dist.get_rank())
+
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+from para_attn.context_parallel import init_context_parallel_mesh
+from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
+
+mesh = init_context_parallel_mesh(
+ pipe.device.type,
+ max_ring_dim_size=2,
+)
+parallelize_pipe(
+ pipe,
+ mesh=mesh,
+)
+parallelize_vae(pipe.vae, mesh=mesh._flatten())
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(
+ pipe,
+ residual_diff_threshold=0.12, # Use a larger value to make the cache take effect
+)
+
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
+
+quantize_(pipe.text_encoder, float8_weight_only())
+quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
+torch._inductor.config.reorder_for_compute_comm_overlap = True
+pipe.transformer = torch.compile(
+ pipe.transformer, mode="max-autotune-no-cudagraphs",
+)
+
+# Enable memory savings
+# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())
+# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())
+
+for i in range(2):
+ begin = time.time()
+ image = pipe(
+ "A cat holding a sign that says hello world",
+ num_inference_steps=28,
+ output_type="pil" if dist.get_rank() == 0 else "pt",
+ ).images[0]
+ end = time.time()
+ if dist.get_rank() == 0:
+ if i == 0:
+ print(f"Warm up time: {end - begin:.2f}s")
+ else:
+ print(f"Time: {end - begin:.2f}s")
+
+if dist.get_rank() == 0:
+ print("Saving image to flux.png")
+ image.save("flux.png")
+
+dist.destroy_process_group()
+```
+
+Save to `run_flux.py` and launch it with [torchrun](https://pytorch.org/docs/stable/elastic/run.html).
+
+```bash
+# Use --nproc_per_node to specify the number of GPUs
+torchrun --nproc_per_node=2 run_flux.py
+```
+
+Inference speed is reduced to 8.20 seconds compared to the baseline, or 3.21x faster, with 2 NVIDIA L20 GPUs. On 4 L20s, inference speed is 3.90 seconds, or 6.75x faster.
+
+
+
+
+The code sample below combines First Block Cache and Context Parallelism for the fastest inference speed.
+
+```python
+import time
+import torch
+import torch.distributed as dist
+from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+from diffusers.utils import export_to_video
+
+dist.init_process_group()
+
+torch.cuda.set_device(dist.get_rank())
+
+model_id = "tencent/HunyuanVideo"
+transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16,
+ revision="refs/pr/18",
+)
+pipe = HunyuanVideoPipeline.from_pretrained(
+ model_id,
+ transformer=transformer,
+ torch_dtype=torch.float16,
+ revision="refs/pr/18",
+).to("cuda")
+
+from para_attn.context_parallel import init_context_parallel_mesh
+from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
+
+mesh = init_context_parallel_mesh(
+ pipe.device.type,
+)
+parallelize_pipe(
+ pipe,
+ mesh=mesh,
+)
+parallelize_vae(pipe.vae, mesh=mesh._flatten())
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(pipe)
+
+# from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
+#
+# torch._inductor.config.reorder_for_compute_comm_overlap = True
+#
+# quantize_(pipe.text_encoder, float8_weight_only())
+# quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
+# pipe.transformer = torch.compile(
+# pipe.transformer, mode="max-autotune-no-cudagraphs",
+# )
+
+# Enable memory savings
+pipe.vae.enable_tiling()
+# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())
+# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())
+
+for i in range(2):
+ begin = time.time()
+ output = pipe(
+ prompt="A cat walks on the grass, realistic",
+ height=720,
+ width=1280,
+ num_frames=129,
+ num_inference_steps=1 if i == 0 else 30,
+ output_type="pil" if dist.get_rank() == 0 else "pt",
+ ).frames[0]
+ end = time.time()
+ if dist.get_rank() == 0:
+ if i == 0:
+ print(f"Warm up time: {end - begin:.2f}s")
+ else:
+ print(f"Time: {end - begin:.2f}s")
+
+if dist.get_rank() == 0:
+ print("Saving video to hunyuan_video.mp4")
+ export_to_video(output, "hunyuan_video.mp4", fps=15)
+
+dist.destroy_process_group()
+```
+
+Save to `run_hunyuan_video.py` and launch it with [torchrun](https://pytorch.org/docs/stable/elastic/run.html).
+
+```bash
+# Use --nproc_per_node to specify the number of GPUs
+torchrun --nproc_per_node=8 run_hunyuan_video.py
+```
+
+Inference speed is reduced to 649.23 seconds compared to the baseline, or 5.66x faster, with 8 NVIDIA L20 GPUs.
+
+
+
+
+## Benchmarks
+
+
+
+
+| GPU Type | Number of GPUs | Optimizations | Wall Time (s) | Speedup |
+| - | - | - | - | - |
+| NVIDIA L20 | 1 | Baseline | 26.36 | 1.00x |
+| NVIDIA L20 | 1 | FBCache (rdt=0.08) | 17.01 | 1.55x |
+| NVIDIA L20 | 1 | FP8 DQ | 13.40 | 1.96x |
+| NVIDIA L20 | 1 | FBCache (rdt=0.12) + FP8 DQ | 7.56 | 3.48x |
+| NVIDIA L20 | 2 | FBCache (rdt=0.12) + FP8 DQ + CP | 4.92 | 5.35x |
+| NVIDIA L20 | 4 | FBCache (rdt=0.12) + FP8 DQ + CP | 3.90 | 6.75x |
+
+
+
+
+| GPU Type | Number of GPUs | Optimizations | Wall Time (s) | Speedup |
+| - | - | - | - | - |
+| NVIDIA L20 | 1 | Baseline | 3675.71 | 1.00x |
+| NVIDIA L20 | 1 | FBCache | 2271.06 | 1.62x |
+| NVIDIA L20 | 2 | FBCache + CP | 1132.90 | 3.24x |
+| NVIDIA L20 | 4 | FBCache + CP | 718.15 | 5.12x |
+| NVIDIA L20 | 8 | FBCache + CP | 649.23 | 5.66x |
+
+
+
diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md
index f272346aa2e2..266daa01935e 100644
--- a/docs/source/en/quantization/bitsandbytes.md
+++ b/docs/source/en/quantization/bitsandbytes.md
@@ -17,6 +17,12 @@ specific language governing permissions and limitations under the License.
4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.
+This guide demonstrates how quantization can enable running
+[FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
+on less than 16GB of VRAM and even on a free Google
+Colab instance.
+
+
To use bitsandbytes, make sure you have the following libraries installed:
@@ -31,82 +37,167 @@ Now you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixi
Quantizing a model in 8-bit halves the memory-usage:
+bitsandbytes is supported in both Transformers and Diffusers, so you can quantize both the
+[`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].
+
+For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bfloat16`.
+
+> [!TIP]
+> The [`CLIPTextModel`] and [`AutoencoderKL`] aren't quantized because they're already small in size and because [`AutoencoderKL`] only has a few `torch.nn.Linear` layers.
+
```py
-from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
-quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+from diffusers import FluxTransformer2DModel
+from transformers import T5EncoderModel
-model_8bit = FluxTransformer2DModel.from_pretrained(
- "black-forest-labs/FLUX.1-dev",
- subfolder="transformer",
- quantization_config=quantization_config
+quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True,)
+
+text_encoder_2_8bit = T5EncoderModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="text_encoder_2",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
)
-```
-By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True,)
-```py
-from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+transformer_8bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+```
-quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
-model_8bit = FluxTransformer2DModel.from_pretrained(
- "black-forest-labs/FLUX.1-dev",
+```diff
+transformer_8bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
subfolder="transformer",
- quantization_config=quantization_config,
- torch_dtype=torch.float32
+ quantization_config=quant_config,
++ torch_dtype=torch.float32,
)
-model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype
```
-Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights.
+Let's generate an image using our quantized models.
+
+Setting `device_map="auto"` automatically fills all available space on the GPU(s) first, then the
+CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.
```py
-from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ transformer=transformer_8bit,
+ text_encoder_2=text_encoder_2_8bit,
+ torch_dtype=torch.float16,
+ device_map="auto",
+)
-quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+pipe_kwargs = {
+ "prompt": "A cat holding a sign that says hello world",
+ "height": 1024,
+ "width": 1024,
+ "guidance_scale": 3.5,
+ "num_inference_steps": 50,
+ "max_sequence_length": 512,
+}
-model_8bit = FluxTransformer2DModel.from_pretrained(
- "black-forest-labs/FLUX.1-dev",
- subfolder="transformer",
- quantization_config=quantization_config
-)
+image = pipe(**pipe_kwargs, generator=torch.manual_seed(0),).images[0]
```
+
+
+
+
+When there is enough memory, you can also directly move the pipeline to the GPU with `.to("cuda")` and apply [`~DiffusionPipeline.enable_model_cpu_offload`] to optimize GPU memory usage.
+
+Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 8-bit models locally with [`~ModelMixin.save_pretrained`].
+
Quantizing a model in 4-bit reduces your memory-usage by 4x:
+bitsandbytes is supported in both Transformers and Diffusers, so you can can quantize both the
+[`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].
+
+For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bfloat16`.
+
+> [!TIP]
+> The [`CLIPTextModel`] and [`AutoencoderKL`] aren't quantized because they're already small in size and because [`AutoencoderKL`] only has a few `torch.nn.Linear` layers.
+
```py
-from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
-quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+from diffusers import FluxTransformer2DModel
+from transformers import T5EncoderModel
-model_4bit = FluxTransformer2DModel.from_pretrained(
- "black-forest-labs/FLUX.1-dev",
+quant_config = TransformersBitsAndBytesConfig(load_in_4bit=True,)
+
+text_encoder_2_4bit = T5EncoderModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="text_encoder_2",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True,)
+
+transformer_4bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
subfolder="transformer",
- quantization_config=quantization_config
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
)
```
-By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
+By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
-```py
-from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+```diff
+transformer_4bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quant_config,
++ torch_dtype=torch.float32,
+)
+```
-quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+Let's generate an image using our quantized models.
-model_4bit = FluxTransformer2DModel.from_pretrained(
- "black-forest-labs/FLUX.1-dev",
- subfolder="transformer",
- quantization_config=quantization_config,
- torch_dtype=torch.float32
+Setting `device_map="auto"` automatically fills all available space on the GPU(s) first, then the CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.
+
+```py
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ transformer=transformer_4bit,
+ text_encoder_2=text_encoder_2_4bit,
+ torch_dtype=torch.float16,
+ device_map="auto",
)
-model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype
+
+pipe_kwargs = {
+ "prompt": "A cat holding a sign that says hello world",
+ "height": 1024,
+ "width": 1024,
+ "guidance_scale": 3.5,
+ "num_inference_steps": 50,
+ "max_sequence_length": 512,
+}
+
+image = pipe(**pipe_kwargs, generator=torch.manual_seed(0),).images[0]
```
-Call [`~ModelMixin.push_to_hub`] after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
+
+
+
+
+When there is enough memory, you can also directly move the pipeline to the GPU with `.to("cuda")` and apply [`~DiffusionPipeline.enable_model_cpu_offload`] to optimize GPU memory usage.
+
+Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
@@ -131,7 +222,7 @@ from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = FluxTransformer2DModel.from_pretrained(
- "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
+ "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
```
@@ -211,17 +302,34 @@ quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dty
NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]:
```py
-from diffusers import BitsAndBytesConfig
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
+
+from diffusers import FluxTransformer2DModel
+from transformers import T5EncoderModel
-nf4_config = BitsAndBytesConfig(
+quant_config = TransformersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
)
-model_nf4 = SD3Transformer2DModel.from_pretrained(
- "stabilityai/stable-diffusion-3-medium-diffusers",
+text_encoder_2_4bit = T5EncoderModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="text_encoder_2",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+)
+
+transformer_4bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
subfolder="transformer",
- quantization_config=nf4_config,
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
)
```
@@ -232,36 +340,77 @@ For inference, the `bnb_4bit_quant_type` does not have a huge impact on performa
Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter.
```py
-from diffusers import BitsAndBytesConfig
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
+
+from diffusers import FluxTransformer2DModel
+from transformers import T5EncoderModel
-double_quant_config = BitsAndBytesConfig(
+quant_config = TransformersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
)
-double_quant_model = SD3Transformer2DModel.from_pretrained(
- "stabilityai/stable-diffusion-3-medium-diffusers",
+text_encoder_2_4bit = T5EncoderModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="text_encoder_2",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_use_double_quant=True,
+)
+
+transformer_4bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
subfolder="transformer",
- quantization_config=double_quant_config,
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
)
```
## Dequantizing `bitsandbytes` models
-Once quantized, you can dequantize the model to the original precision but this might result in a small quality loss of the model. Make sure you have enough GPU RAM to fit the dequantized model.
+Once quantized, you can dequantize a model to its original precision, but this might result in a small loss of quality. Make sure you have enough GPU RAM to fit the dequantized model.
```python
-from diffusers import BitsAndBytesConfig
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
-double_quant_config = BitsAndBytesConfig(
+from diffusers import FluxTransformer2DModel
+from transformers import T5EncoderModel
+
+quant_config = TransformersBitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
)
-double_quant_model = SD3Transformer2DModel.from_pretrained(
- "stabilityai/stable-diffusion-3-medium-diffusers",
+text_encoder_2_4bit = T5EncoderModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="text_encoder_2",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_use_double_quant=True,
+)
+
+transformer_4bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
subfolder="transformer",
- quantization_config=double_quant_config,
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
)
-model.dequantize()
-```
\ No newline at end of file
+
+text_encoder_2_4bit.dequantize()
+transformer_4bit.dequantize()
+```
+
+## Resources
+
+* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
+* [Training](https://gist.github.com/sayakpaul/05afd428bc089b47af7c016e42004527)
\ No newline at end of file
diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md
new file mode 100644
index 000000000000..f7537d7e7882
--- /dev/null
+++ b/docs/source/en/quantization/gguf.md
@@ -0,0 +1,69 @@
+
+
+# GGUF
+
+The GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block wise quantization options. Diffusers supports loading checkpoints prequantized and saved in the GGUF format via `from_single_file` loading with Model classes. Loading GGUF checkpoints via Pipelines is currently not supported.
+
+The following example will load the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant.
+
+Before starting please install gguf in your environment
+
+```shell
+pip install -U gguf
+```
+
+Since GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`].
+
+When using GGUF checkpoints, the quantized weights remain in a low memory `dtype`(typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.
+
+The functions used for dynamic dequantizatation are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the Pytorch ports of the original [`numpy`](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py) implementation by [compilade](https://github.com/compilade).
+
+```python
+import torch
+
+from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
+
+ckpt_path = (
+ "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+)
+transformer = FluxTransformer2DModel.from_single_file(
+ ckpt_path,
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+ torch_dtype=torch.bfloat16,
+)
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16,
+)
+pipe.enable_model_cpu_offload()
+prompt = "A cat holding a sign that says hello world"
+image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
+image.save("flux-gguf.png")
+```
+
+## Supported Quantization Types
+
+- BF16
+- Q4_0
+- Q4_1
+- Q5_0
+- Q5_1
+- Q8_0
+- Q2_K
+- Q3_K
+- Q4_K
+- Q5_K
+- Q6_K
+
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
index d8adbc85a259..93323f86c7fc 100644
--- a/docs/source/en/quantization/overview.md
+++ b/docs/source/en/quantization/overview.md
@@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a
-Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
+Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
@@ -32,4 +32,10 @@ If you are new to the quantization field, we recommend you to check out these be
## When to use what?
-This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
\ No newline at end of file
+Diffusers currently supports the following quantization methods.
+- [BitsandBytes](./bitsandbytes)
+- [TorchAO](./torchao)
+- [GGUF](./gguf)
+- [Quanto](./quanto.md)
+
+[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
diff --git a/docs/source/en/quantization/quanto.md b/docs/source/en/quantization/quanto.md
new file mode 100644
index 000000000000..d322d76be267
--- /dev/null
+++ b/docs/source/en/quantization/quanto.md
@@ -0,0 +1,148 @@
+
+
+# Quanto
+
+[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:
+
+- All features are available in eager mode (works with non-traceable models)
+- Supports quantization aware training
+- Quantized models are compatible with `torch.compile`
+- Quantized models are Device agnostic (e.g CUDA,XPU,MPS,CPU)
+
+In order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate`
+
+```shell
+pip install optimum-quanto accelerate
+```
+
+Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, currently, Diffusers only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto.
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel, QuantoConfig
+
+model_id = "black-forest-labs/FLUX.1-dev"
+quantization_config = QuantoConfig(weights_dtype="float8")
+transformer = FluxTransformer2DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+)
+
+pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype)
+pipe.to("cuda")
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(
+ prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
+).images[0]
+image.save("output.png")
+```
+
+## Skipping Quantization on specific modules
+
+It is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in the `QuantoConfig`. Please ensure that the modules passed in to this argument match the keys of the modules in the `state_dict`
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel, QuantoConfig
+
+model_id = "black-forest-labs/FLUX.1-dev"
+quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
+transformer = FluxTransformer2DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+)
+```
+
+## Using `from_single_file` with the Quanto Backend
+
+`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel, QuantoConfig
+
+ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
+quantization_config = QuantoConfig(weights_dtype="float8")
+transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
+```
+
+## Saving Quantized models
+
+Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.
+
+The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized
+with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel, QuantoConfig
+
+model_id = "black-forest-labs/FLUX.1-dev"
+quantization_config = QuantoConfig(weights_dtype="float8")
+transformer = FluxTransformer2DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+)
+# save quantized model to reuse
+transformer.save_pretrained("")
+
+# you can reload your quantized model with
+model = FluxTransformer2DModel.from_pretrained("")
+```
+
+## Using `torch.compile` with Quanto
+
+Currently the Quanto backend supports `torch.compile` for the following quantization types:
+
+- `int8` weights
+
+```python
+import torch
+from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
+
+model_id = "black-forest-labs/FLUX.1-dev"
+quantization_config = QuantoConfig(weights_dtype="int8")
+transformer = FluxTransformer2DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+)
+transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
+
+pipe = FluxPipeline.from_pretrained(
+ model_id, transformer=transformer, torch_dtype=torch_dtype
+)
+pipe.to("cuda")
+images = pipe("A cat holding a sign that says hello").images[0]
+images.save("flux-quanto-compile.png")
+```
+
+## Supported Quantization Types
+
+### Weights
+
+- float8
+- int8
+- int4
+- int2
+
+
diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md
new file mode 100644
index 000000000000..19a8970fa9df
--- /dev/null
+++ b/docs/source/en/quantization/torchao.md
@@ -0,0 +1,156 @@
+
+
+# torchao
+
+[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
+
+Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed.
+
+```bash
+pip install -U torch torchao
+```
+
+
+Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+
+The example below only quantizes the weights to int8.
+
+```python
+import torch
+from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
+
+model_id = "black-forest-labs/FLUX.1-dev"
+dtype = torch.bfloat16
+
+quantization_config = TorchAoConfig("int8wo")
+transformer = FluxTransformer2DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=dtype,
+)
+pipe = FluxPipeline.from_pretrained(
+ model_id,
+ transformer=transformer,
+ torch_dtype=dtype,
+)
+pipe.to("cuda")
+
+# Without quantization: ~31.447 GB
+# With quantization: ~20.40 GB
+print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(
+ prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
+).images[0]
+image.save("output.png")
+```
+
+TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
+
+```python
+# In the above code, add the following after initializing the transformer
+transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
+```
+
+For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
+
+torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
+
+The `TorchAoConfig` class accepts three parameters:
+- `quant_type`: A string value mentioning one of the quantization types below.
+- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`.
+- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
+
+## Supported quantization types
+
+torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.
+
+Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.
+
+Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
+
+The quantization methods supported are as follows:
+
+| **Category** | **Full Function Names** | **Shorthands** |
+|--------------|-------------------------|----------------|
+| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
+| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
+| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
+| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
+
+Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
+
+Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
+
+## Serializing and Deserializing quantized models
+
+To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel, TorchAoConfig
+
+quantization_config = TorchAoConfig("int8wo")
+transformer = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/Flux.1-Dev",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+)
+transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
+```
+
+To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
+
+```python
+import torch
+from diffusers import FluxPipeline, FluxTransformer2DModel
+
+transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
+pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
+image.save("output.png")
+```
+
+If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
+
+```python
+import torch
+from accelerate import init_empty_weights
+from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
+
+# Serialize the model
+transformer = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/Flux.1-Dev",
+ subfolder="transformer",
+ quantization_config=TorchAoConfig("uint4wo"),
+ torch_dtype=torch.bfloat16,
+)
+transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
+# ...
+
+# Load the model
+state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
+with init_empty_weights():
+ transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
+transformer.load_state_dict(state_dict, strict=True, assign=True)
+```
+
+## Resources
+
+- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
+- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)
diff --git a/docs/source/en/training/create_dataset.md b/docs/source/en/training/create_dataset.md
index 38783eff76bd..f3221beb408f 100644
--- a/docs/source/en/training/create_dataset.md
+++ b/docs/source/en/training/create_dataset.md
@@ -1,6 +1,6 @@
# Create a dataset for training
-There are many datasets on the [Hub](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads) to train a model on, but if you can't find one you're interested in or want to use your own, you can create a dataset with the 🤗 [Datasets](hf.co/docs/datasets) library. The dataset structure depends on the task you want to train your model on. The most basic dataset structure is a directory of images for tasks like unconditional image generation. Another dataset structure may be a directory of images and a text file containing their corresponding text captions for tasks like text-to-image generation.
+There are many datasets on the [Hub](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads) to train a model on, but if you can't find one you're interested in or want to use your own, you can create a dataset with the 🤗 [Datasets](https://huggingface.co/docs/datasets) library. The dataset structure depends on the task you want to train your model on. The most basic dataset structure is a directory of images for tasks like unconditional image generation. Another dataset structure may be a directory of images and a text file containing their corresponding text captions for tasks like text-to-image generation.
This guide will show you two ways to create a dataset to finetune on:
@@ -87,4 +87,4 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
Now that you've created a dataset, you can plug it into the `train_data_dir` (if your dataset is local) or `dataset_name` (if your dataset is on the Hub) arguments of a training script.
-For your next steps, feel free to try and use your dataset to train a model for [unconditional generation](unconditional_training) or [text-to-image generation](text2image)!
\ No newline at end of file
+For your next steps, feel free to try and use your dataset to train a model for [unconditional generation](unconditional_training) or [text-to-image generation](text2image)!
diff --git a/docs/source/en/training/custom_diffusion.md b/docs/source/en/training/custom_diffusion.md
index 02fc319709eb..ce02ba843b17 100644
--- a/docs/source/en/training/custom_diffusion.md
+++ b/docs/source/en/training/custom_diffusion.md
@@ -339,7 +339,10 @@ import torch
from huggingface_hub.repocard import RepoCard
from diffusers import DiffusionPipeline
-pipeline = DiffusionPipeline.from_pretrained("sayakpaul/custom-diffusion-cat-wooden-pot", torch_dtype=torch.float16).to("cuda")
+pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16,
+).to("cuda")
+model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
pipeline.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
pipeline.load_textual_inversion(model_id, weight_name=".bin")
pipeline.load_textual_inversion(model_id, weight_name=".bin")
diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md
index 0e1eb7962bf7..79b4f785f30c 100644
--- a/docs/source/en/training/distributed_inference.md
+++ b/docs/source/en/training/distributed_inference.md
@@ -183,7 +183,7 @@ Add the transformer model to the pipeline for denoising, but set the other model
```py
pipeline = FluxPipeline.from_pretrained(
- "black-forest-labs/FLUX.1-dev", ,
+ "black-forest-labs/FLUX.1-dev",
text_encoder=None,
text_encoder_2=None,
tokenizer=None,
diff --git a/docs/source/en/tutorials/basic_training.md b/docs/source/en/tutorials/basic_training.md
index 402c8c59b17d..f8c4a5b84b9f 100644
--- a/docs/source/en/tutorials/basic_training.md
+++ b/docs/source/en/tutorials/basic_training.md
@@ -75,7 +75,7 @@ For convenience, create a `TrainingConfig` class containing the training hyperpa
... push_to_hub = True # whether to upload the saved model to the HF Hub
... hub_model_id = "/" # the name of the repository to create on the HF Hub
-... hub_private_repo = False
+... hub_private_repo = None
... overwrite_output_dir = True # overwrite the old model when re-running the notebook
... seed = 0
diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index 615af55ef5b5..33414a331ea7 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -56,7 +56,7 @@ image
With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`.
-The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method:
+The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method:
```python
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
@@ -85,7 +85,7 @@ By default, if the most up-to-date versions of PEFT and Transformers are detecte
You can also merge different adapter checkpoints for inference to blend their styles together.
-Once again, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
+Once again, use the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
```python
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
@@ -114,7 +114,7 @@ Impressive! As you can see, the model generated an image that mixed the characte
> [!TIP]
> Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide!
-To return to only using one adapter, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `"toy"` adapter:
+To return to only using one adapter, use the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:
```python
pipe.set_adapters("toy")
@@ -127,7 +127,7 @@ image = pipe(
image
```
-Or to disable all adapters entirely, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.disable_lora`] method to return the base model.
+Or to disable all adapters entirely, use the [`~loaders.peft.PeftAdapterMixin.disable_lora`] method to return the base model.
```python
pipe.disable_lora()
@@ -140,7 +140,8 @@ image

### Customize adapters strength
-For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].
+
+For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~loaders.peft.PeftAdapterMixin.set_adapters`].
For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
```python
@@ -195,7 +196,7 @@ image

-## Manage active adapters
+## Manage adapters
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.StableDiffusionLoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
@@ -212,3 +213,15 @@ list_adapters_component_wise = pipe.get_list_adapters()
list_adapters_component_wise
{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
```
+
+The [`~loaders.peft.PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model.
+
+```py
+pipe.delete_adapters("toy")
+pipe.get_active_adapters()
+["pixel"]
+```
+
+## PeftInputAutocastDisableHook
+
+[[autodoc]] hooks.layerwise_casting.PeftInputAutocastDisableHook
diff --git a/docs/source/en/using-diffusers/callback.md b/docs/source/en/using-diffusers/callback.md
index 68c621ffc50d..2462fed1a3cf 100644
--- a/docs/source/en/using-diffusers/callback.md
+++ b/docs/source/en/using-diffusers/callback.md
@@ -157,6 +157,84 @@ pipeline(
)
```
+## IP Adapter Cutoff
+
+IP Adapter is an image prompt adapter that can be used for diffusion models without any changes to the underlying model. We can use the IP Adapter Cutoff Callback to disable the IP Adapter after a certain number of steps. To set up the callback, you need to specify the number of denoising steps after which the callback comes into effect. You can do so by using either one of these two arguments:
+
+- `cutoff_step_ratio`: Float number with the ratio of the steps.
+- `cutoff_step_index`: Integer number with the exact number of the step.
+
+We need to download the diffusion model and load the ip_adapter for it as follows:
+
+```py
+from diffusers import AutoPipelineForText2Image
+from diffusers.utils import load_image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+pipeline.set_ip_adapter_scale(0.6)
+```
+The setup for the callback should look something like this:
+
+```py
+
+from diffusers import AutoPipelineForText2Image
+from diffusers.callbacks import IPAdapterScaleCutoffCallback
+from diffusers.utils import load_image
+import torch
+
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+
+pipeline.load_ip_adapter(
+ "h94/IP-Adapter",
+ subfolder="sdxl_models",
+ weight_name="ip-adapter_sdxl.bin"
+)
+
+pipeline.set_ip_adapter_scale(0.6)
+
+
+callback = IPAdapterScaleCutoffCallback(
+ cutoff_step_ratio=None,
+ cutoff_step_index=5
+)
+
+image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png"
+)
+
+generator = torch.Generator(device="cuda").manual_seed(2628670641)
+
+images = pipeline(
+ prompt="a tiger sitting in a chair drinking orange juice",
+ ip_adapter_image=image,
+ negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
+ generator=generator,
+ num_inference_steps=50,
+ callback_on_step_end=callback,
+).images
+
+images[0].save("custom_callback_img.png")
+```
+
+
+
+
+
without IPAdapterScaleCutoffCallback
+
+
+
+
with IPAdapterScaleCutoffCallback
+
+
+
+
## Display image after each generation step
> [!TIP]
diff --git a/docs/source/en/using-diffusers/consisid.md b/docs/source/en/using-diffusers/consisid.md
new file mode 100644
index 000000000000..07c13c4c66b3
--- /dev/null
+++ b/docs/source/en/using-diffusers/consisid.md
@@ -0,0 +1,96 @@
+
+# ConsisID
+
+[ConsisID](https://github.com/PKU-YuanGroup/ConsisID) is an identity-preserving text-to-video generation model that keeps the face consistent in the generated video by frequency decomposition. The main features of ConsisID are:
+
+- Frequency decomposition: The characteristics of the DiT architecture are analyzed from the frequency domain perspective, and based on these characteristics, a reasonable control information injection method is designed.
+- Consistency training strategy: A coarse-to-fine training strategy, dynamic masking loss, and dynamic cross-face loss further enhance the model's generalization ability and identity preservation performance.
+- Inference without finetuning: Previous methods required case-by-case finetuning of the input ID before inference, leading to significant time and computational costs. In contrast, ConsisID is tuning-free.
+
+This guide will walk you through using ConsisID for use cases.
+
+## Load Model Checkpoints
+
+Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
+
+```python
+# !pip install consisid_eva_clip insightface facexlib
+import torch
+from diffusers import ConsisIDPipeline
+from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
+from huggingface_hub import snapshot_download
+
+# Download ckpts
+snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
+
+# Load face helper model to preprocess input face image
+face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
+
+# Load consisid base model
+pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+```
+
+## Identity-Preserving Text-to-Video
+
+For identity-preserving text-to-video, pass a text prompt and an image contain clear face (e.g., preferably half-body or full-body). By default, ConsisID generates a 720x480 video for the best results.
+
+```python
+from diffusers.utils import export_to_video
+
+prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel."
+image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true"
+
+id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(face_helper_1, face_clip_model, face_helper_2, eva_transform_mean, eva_transform_std, face_main_model, "cuda", torch.bfloat16, image, is_align_face=True)
+
+video = pipe(image=image, prompt=prompt, num_inference_steps=50, guidance_scale=6.0, use_dynamic_cfg=False, id_vit_hidden=id_vit_hidden, id_cond=id_cond, kps_cond=face_kps, generator=torch.Generator("cuda").manual_seed(42))
+export_to_video(video.frames[0], "output.mp4", fps=8)
+```
+
+
+ Face Image
+ Video
+ Description
+
+
+
+ The video, in a beautifully crafted animated style, features a confident woman riding a horse through a lush forest clearing. Her expression is focused yet serene as she adjusts her wide-brimmed hat with a practiced hand. She wears a flowy bohemian dress, which moves gracefully with the rhythm of the horse, the fabric flowing fluidly in the animated motion. The dappled sunlight filters through the trees, casting soft, painterly patterns on the forest floor. Her posture is poised, showing both control and elegance as she guides the horse with ease. The animation's gentle, fluid style adds a dreamlike quality to the scene, with the woman’s calm demeanor and the peaceful surroundings evoking a sense of freedom and harmony.
+
+
+
+
+ The video, in a captivating animated style, shows a woman standing in the center of a snowy forest, her eyes narrowed in concentration as she extends her hand forward. She is dressed in a deep blue cloak, her breath visible in the cold air, which is rendered with soft, ethereal strokes. A faint smile plays on her lips as she summons a wisp of ice magic, watching with focus as the surrounding trees and ground begin to shimmer and freeze, covered in delicate ice crystals. The animation’s fluid motion brings the magic to life, with the frost spreading outward in intricate, sparkling patterns. The environment is painted with soft, watercolor-like hues, enhancing the magical, dreamlike atmosphere. The overall mood is serene yet powerful, with the quiet winter air amplifying the delicate beauty of the frozen scene.
+
+
+
+
+ The animation features a whimsical portrait of a balloon seller standing in a gentle breeze, captured with soft, hazy brushstrokes that evoke the feel of a serene spring day. His face is framed by a gentle smile, his eyes squinting slightly against the sun, while a few wisps of hair flutter in the wind. He is dressed in a light, pastel-colored shirt, and the balloons around him sway with the wind, adding a sense of playfulness to the scene. The background blurs softly, with hints of a vibrant market or park, enhancing the light-hearted, yet tender mood of the moment.
+
+
+
+
+ The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.
+
+
+
+
+ The video features a baby wearing a bright superhero cape, standing confidently with arms raised in a powerful pose. The baby has a determined look on their face, with eyes wide and lips pursed in concentration, as if ready to take on a challenge. The setting appears playful, with colorful toys scattered around and a soft rug underfoot, while sunlight streams through a nearby window, highlighting the fluttering cape and adding to the impression of heroism. The overall atmosphere is lighthearted and fun, with the baby's expressions capturing a mix of innocence and an adorable attempt at bravery, as if truly ready to save the day.
+
+
+
+## Resources
+
+Learn more about ConsisID with the following resources.
+- A [video](https://www.youtube.com/watch?v=PhlgC-bI5SQ) demonstrating ConsisID's main features.
+- The research paper, [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440) for more details.
diff --git a/docs/source/en/using-diffusers/create_a_server.md b/docs/source/en/using-diffusers/create_a_server.md
new file mode 100644
index 000000000000..8ad0ed3cbe6a
--- /dev/null
+++ b/docs/source/en/using-diffusers/create_a_server.md
@@ -0,0 +1,61 @@
+
+# Create a server
+
+Diffusers' pipelines can be used as an inference engine for a server. It supports concurrent and multithreaded requests to generate images that may be requested by multiple users at the same time.
+
+This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server, but feel free to use any pipeline you want.
+
+
+Start by navigating to the `examples/server` folder and installing all of the dependencies.
+
+```py
+pip install .
+pip install -f requirements.txt
+```
+
+Launch the server with the following command.
+
+```py
+python server.py
+```
+
+The server is accessed at http://localhost:8000. You can curl this model with the following command.
+```
+curl -X POST -H "Content-Type: application/json" --data '{"model": "something", "prompt": "a kitten in front of a fireplace"}' http://localhost:8000/v1/images/generations
+```
+
+If you need to upgrade some dependencies, you can use either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). For example, upgrade the dependencies with `uv` using the following command.
+
+```
+uv pip compile requirements.in -o requirements.txt
+```
+
+
+The server is built with [FastAPI](https://fastapi.tiangolo.com/async/). The endpoint for `v1/images/generations` is shown below.
+```py
+@app.post("/v1/images/generations")
+async def generate_image(image_input: TextToImageInput):
+ try:
+ loop = asyncio.get_event_loop()
+ scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
+ pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
+ generator = torch.Generator(device="cuda")
+ generator.manual_seed(random.randint(0, 10000000))
+ output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))
+ logger.info(f"output: {output}")
+ image_url = save_image(output.images[0])
+ return {"data": [{"url": image_url}]}
+ except Exception as e:
+ if isinstance(e, HTTPException):
+ raise e
+ elif hasattr(e, 'message'):
+ raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
+ raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
+```
+The `generate_image` function is defined as asynchronous with the [async](https://fastapi.tiangolo.com/async/) keyword so that FastAPI knows that whatever is happening in this function won't necessarily return a result right away. Once it hits some point in the function that it needs to await some other [Task](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task), the main thread goes back to answering other HTTP requests. This is shown in the code below with the [await](https://fastapi.tiangolo.com/async/#async-and-await) keyword.
+```py
+output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))
+```
+At this point, the execution of the pipeline function is placed onto a [new thread](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor), and the main thread performs other things until a result is returned from the `pipeline`.
+
+Another important aspect of this implementation is creating a `pipeline` from `shared_pipeline`. The goal behind this is to avoid loading the underlying model more than once onto the GPU while still allowing for each new request that is running on a separate thread to have its own generator and scheduler. The scheduler, in particular, is not thread-safe, and it will cause errors like: `IndexError: index 21 is out of bounds for dimension 0 with size 21` if you try to use the same scheduler across multiple threads.
diff --git a/docs/source/en/using-diffusers/img2img.md b/docs/source/en/using-diffusers/img2img.md
index 4618731830df..d9902081fde5 100644
--- a/docs/source/en/using-diffusers/img2img.md
+++ b/docs/source/en/using-diffusers/img2img.md
@@ -461,12 +461,12 @@ Chain it to an upscaler pipeline to increase the image resolution:
from diffusers import StableDiffusionLatentUpscalePipeline
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
- "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+ "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, use_safetensors=True
)
upscaler.enable_model_cpu_offload()
upscaler.enable_xformers_memory_efficient_attention()
-image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
+image_2 = upscaler(prompt, image=image_1).images[0]
```
Finally, chain it to a super-resolution pipeline to further enhance the resolution:
diff --git a/docs/source/en/using-diffusers/loading.md b/docs/source/en/using-diffusers/loading.md
index a45667fdc464..d48004d7400c 100644
--- a/docs/source/en/using-diffusers/loading.md
+++ b/docs/source/en/using-diffusers/loading.md
@@ -95,6 +95,23 @@ Use the Space below to gauge a pipeline's memory requirements before you downloa
>
+### Specifying Component-Specific Data Types
+
+You can customize the data types for individual sub-models by passing a dictionary to the `torch_dtype` parameter. This allows you to load different components of a pipeline in different floating point precisions. For instance, if you want to load the transformer with `torch.bfloat16` and all other components with `torch.float16`, you can pass a dictionary mapping:
+
+```python
+from diffusers import HunyuanVideoPipeline
+import torch
+
+pipe = HunyuanVideoPipeline.from_pretrained(
+ "hunyuanvideo-community/HunyuanVideo",
+ torch_dtype={'transformer': torch.bfloat16, 'default': torch.float16},
+)
+print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)
+```
+
+If a component is not explicitly specified in the dictionary and no `default` is provided, it will be loaded with `torch.float32`.
+
### Local pipeline
To load a pipeline locally, use [git-lfs](https://git-lfs.github.com/) to manually download a checkpoint to your local disk.
diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md
index a25d452e5186..e16c1322e5d1 100644
--- a/docs/source/en/using-diffusers/loading_adapters.md
+++ b/docs/source/en/using-diffusers/loading_adapters.md
@@ -134,14 +134,16 @@ The [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method loads L
- the LoRA weights don't have separate identifiers for the UNet and text encoder
- the LoRA weights have separate identifiers for the UNet and text encoder
-But if you only need to load LoRA weights into the UNet, then you can use the [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. Let's load the [jbilcke-hf/sdxl-cinematic-1](https://huggingface.co/jbilcke-hf/sdxl-cinematic-1) LoRA:
+To directly load (and save) a LoRA adapter at the *model-level*, use [`~PeftAdapterMixin.load_lora_adapter`], which builds and prepares the necessary model configuration for the adapter. Like [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`PeftAdapterMixin.load_lora_adapter`] can load LoRAs for both the UNet and text encoder. For example, if you're loading a LoRA for the UNet, [`PeftAdapterMixin.load_lora_adapter`] ignores the keys for the text encoder.
+
+Use the `weight_name` parameter to specify the specific weight file and the `prefix` parameter to filter for the appropriate state dicts (`"unet"` in this case) to load.
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
-pipeline.unet.load_attn_procs("jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors")
+pipeline.unet.load_lora_adapter("jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", prefix="unet")
# use cnmt in the prompt to trigger the LoRA
prompt = "A cute cnmt eating a slice of pizza, stunning color scheme, masterpiece, illustration"
@@ -153,6 +155,8 @@ image
+Save an adapter with [`~PeftAdapterMixin.save_lora_adapter`].
+
To unload the LoRA weights, use the [`~loaders.StableDiffusionLoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
```py
diff --git a/docs/source/en/using-diffusers/marigold_usage.md b/docs/source/en/using-diffusers/marigold_usage.md
index e9756b7f1c8e..b8e9a5838e8d 100644
--- a/docs/source/en/using-diffusers/marigold_usage.md
+++ b/docs/source/en/using-diffusers/marigold_usage.md
@@ -1,4 +1,6 @@
-
-# Marigold Pipelines for Computer Vision Tasks
+# Marigold Computer Vision
-[Marigold](../api/pipelines/marigold) is a novel diffusion-based dense prediction approach, and a set of pipelines for various computer vision tasks, such as monocular depth estimation.
+**Marigold** is a diffusion-based [method](https://huggingface.co/papers/2312.02145) and a collection of [pipelines](../api/pipelines/marigold) designed for
+dense computer vision tasks, including **monocular depth prediction**, **surface normals estimation**, and **intrinsic
+image decomposition**.
-This guide will show you how to use Marigold to obtain fast and high-quality predictions for images and videos.
+This guide will walk you through using Marigold to generate fast and high-quality predictions for images and videos.
-Each pipeline supports one Computer Vision task, which takes an input RGB image as input and produces a *prediction* of the modality of interest, such as a depth map of the input image.
-Currently, the following tasks are implemented:
+Each pipeline is tailored for a specific computer vision task, processing an input RGB image and generating a
+corresponding prediction.
+Currently, the following computer vision tasks are implemented:
-| Pipeline | Predicted Modalities | Demos |
-|---------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------:|
-| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-lcm), [Slow Original Demo (DDIM)](https://huggingface.co/spaces/prs-eth/marigold) |
-| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-normals-lcm) |
+| Pipeline | Recommended Model Checkpoints | Spaces (Interactive Apps) | Predicted Modalities |
+|---------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1) | [Depth Estimation](https://huggingface.co/spaces/prs-eth/marigold) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) |
+| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [prs-eth/marigold-normals-v1-1](https://huggingface.co/prs-eth/marigold-normals-v1-1) | [Surface Normals Estimation](https://huggingface.co/spaces/prs-eth/marigold-normals) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) |
+| [MarigoldIntrinsicsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py) | [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1), [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | [Intrinsic Image Decomposition](https://huggingface.co/spaces/prs-eth/marigold-iid) | [Albedo](https://en.wikipedia.org/wiki/Albedo), [Materials](https://www.n.aiq3d.com/wiki/roughnessmetalnessao-map), [Lighting](https://en.wikipedia.org/wiki/Diffuse_reflection) |
-The original checkpoints can be found under the [PRS-ETH](https://huggingface.co/prs-eth/) Hugging Face organization.
-These checkpoints are meant to work with diffusers pipelines and the [original codebase](https://github.com/prs-eth/marigold).
-The original code can also be used to train new checkpoints.
+All original checkpoints are available under the [PRS-ETH](https://huggingface.co/prs-eth/) organization on Hugging Face.
+They are designed for use with diffusers pipelines and the [original codebase](https://github.com/prs-eth/marigold), which can also be used to train
+new model checkpoints.
+The following is a summary of the recommended checkpoints, all of which produce reliable results with 1 to 4 steps.
-| Checkpoint | Modality | Comment |
-|-----------------------------------------------------------------------------------------------|----------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| [prs-eth/marigold-v1-0](https://huggingface.co/prs-eth/marigold-v1-0) | Depth | The first Marigold Depth checkpoint, which predicts *affine-invariant depth* maps. The performance of this checkpoint in benchmarks was studied in the original [paper](https://huggingface.co/papers/2312.02145). Designed to be used with the `DDIMScheduler` at inference, it requires at least 10 steps to get reliable predictions. Affine-invariant depth prediction has a range of values in each pixel between 0 (near plane) and 1 (far plane); both planes are chosen by the model as part of the inference process. See the `MarigoldImageProcessor` reference for visualization utilities. |
-| [prs-eth/marigold-depth-lcm-v1-0](https://huggingface.co/prs-eth/marigold-depth-lcm-v1-0) | Depth | The fast Marigold Depth checkpoint, fine-tuned from `prs-eth/marigold-v1-0`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. |
-| [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | A preview checkpoint for the Marigold Normals pipeline. Designed to be used with the `DDIMScheduler` at inference, it requires at least 10 steps to get reliable predictions. The surface normals predictions are unit-length 3D vectors with values in the range from -1 to 1. *This checkpoint will be phased out after the release of `v1-0` version.* |
-| [prs-eth/marigold-normals-lcm-v0-1](https://huggingface.co/prs-eth/marigold-normals-lcm-v0-1) | Normals | The fast Marigold Normals checkpoint, fine-tuned from `prs-eth/marigold-normals-v0-1`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. *This checkpoint will be phased out after the release of `v1-0` version.* |
-The examples below are mostly given for depth prediction, but they can be universally applied with other supported modalities.
+| Checkpoint | Modality | Comment |
+|-----------------------------------------------------------------------------------------------------|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1) | Depth | Affine-invariant depth prediction assigns each pixel a value between 0 (near plane) and 1 (far plane), with both planes determined by the model during inference. |
+| [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | The surface normals predictions are unit-length 3D vectors in the screen space camera, with values in the range from -1 to 1. |
+| [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1) | Intrinsics | InteriorVerse decomposition is comprised of Albedo and two BRDF material properties: Roughness and Metallicity. |
+| [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | Intrinsics | HyperSim decomposition of an image \\(I\\) is comprised of Albedo \\(A\\), Diffuse shading \\(S\\), and Non-diffuse residual \\(R\\): \\(I = A*S+R\\). |
+
+The examples below are mostly given for depth prediction, but they can be universally applied to other supported
+modalities.
We showcase the predictions using the same input image of Albert Einstein generated by Midjourney.
This makes it easier to compare visualizations of the predictions across various modalities and checkpoints.
@@ -47,19 +56,21 @@ This makes it easier to compare visualizations of the predictions across various
-### Depth Prediction Quick Start
+## Depth Prediction
-To get the first depth prediction, load `prs-eth/marigold-depth-lcm-v1-0` checkpoint into `MarigoldDepthPipeline` pipeline, put the image through the pipeline, and save the predictions:
+To get a depth prediction, load the `prs-eth/marigold-depth-v1-1` checkpoint into [`MarigoldDepthPipeline`],
+put the image through the pipeline, and save the predictions:
```python
import diffusers
import torch
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
- "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
+ "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
).to("cuda")
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+
depth = pipe(image)
vis = pipe.image_processor.visualize_depth(depth.prediction)
@@ -69,10 +80,13 @@ depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction)
depth_16bit[0].save("einstein_depth_16bit.png")
```
-The visualization function for depth [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_depth`] applies one of [matplotlib's colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html) (`Spectral` by default) to map the predicted pixel values from a single-channel `[0, 1]` depth range into an RGB image.
-With the `Spectral` colormap, pixels with near depth are painted red, and far pixels are assigned blue color.
+The [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_depth`] function applies one of
+[matplotlib's colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html) (`Spectral` by default) to map the predicted pixel values from a single-channel `[0, 1]`
+depth range into an RGB image.
+With the `Spectral` colormap, pixels with near depth are painted red, and far pixels are blue.
The 16-bit PNG file stores the single channel values mapped linearly from the `[0, 1]` range into `[0, 65535]`.
-Below are the raw and the visualized predictions; as can be seen, dark areas (mustache) are easier to distinguish in the visualization:
+Below are the raw and the visualized predictions. The darker and closer areas (mustache) are easier to distinguish in
+the visualization.
@@ -89,28 +103,33 @@ Below are the raw and the visualized predictions; as can be seen, dark areas (mu
-### Surface Normals Prediction Quick Start
+## Surface Normals Estimation
-Load `prs-eth/marigold-normals-lcm-v0-1` checkpoint into `MarigoldNormalsPipeline` pipeline, put the image through the pipeline, and save the predictions:
+Load the `prs-eth/marigold-normals-v1-1` checkpoint into [`MarigoldNormalsPipeline`], put the image through the
+pipeline, and save the predictions:
```python
import diffusers
import torch
pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
- "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
+ "prs-eth/marigold-normals-v1-1", variant="fp16", torch_dtype=torch.float16
).to("cuda")
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+
normals = pipe(image)
vis = pipe.image_processor.visualize_normals(normals.prediction)
vis[0].save("einstein_normals.png")
```
-The visualization function for normals [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals`] maps the three-dimensional prediction with pixel values in the range `[-1, 1]` into an RGB image.
-The visualization function supports flipping surface normals axes to make the visualization compatible with other choices of the frame of reference.
-Conceptually, each pixel is painted according to the surface normal vector in the frame of reference, where `X` axis points right, `Y` axis points up, and `Z` axis points at the viewer.
+The [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals`] maps the three-dimensional
+prediction with pixel values in the range `[-1, 1]` into an RGB image.
+The visualization function supports flipping surface normals axes to make the visualization compatible with other
+choices of the frame of reference.
+Conceptually, each pixel is painted according to the surface normal vector in the frame of reference, where `X` axis
+points right, `Y` axis points up, and `Z` axis points at the viewer.
Below is the visualized prediction:
@@ -122,208 +141,226 @@ Below is the visualized prediction:
-In this example, the nose tip almost certainly has a point on the surface, in which the surface normal vector points straight at the viewer, meaning that its coordinates are `[0, 0, 1]`.
+In this example, the nose tip almost certainly has a point on the surface, in which the surface normal vector points
+straight at the viewer, meaning that its coordinates are `[0, 0, 1]`.
This vector maps to the RGB `[128, 128, 255]`, which corresponds to the violet-blue color.
-Similarly, a surface normal on the cheek in the right part of the image has a large `X` component, which increases the red hue.
+Similarly, a surface normal on the cheek in the right part of the image has a large `X` component, which increases the
+red hue.
Points on the shoulders pointing up with a large `Y` promote green color.
-### Speeding up inference
+## Intrinsic Image Decomposition
-The above quick start snippets are already optimized for speed: they load the LCM checkpoint, use the `fp16` variant of weights and computation, and perform just one denoising diffusion step.
-The `pipe(image)` call completes in 280ms on RTX 3090 GPU.
-Internally, the input image is encoded with the Stable Diffusion VAE encoder, then the U-Net performs one denoising step, and finally, the prediction latent is decoded with the VAE decoder into pixel space.
-In this case, two out of three module calls are dedicated to converting between pixel and latent space of LDM.
-Because Marigold's latent space is compatible with the base Stable Diffusion, it is possible to speed up the pipeline call by more than 3x (85ms on RTX 3090) by using a [lightweight replacement of the SD VAE](../api/models/autoencoder_tiny):
+Marigold provides two models for Intrinsic Image Decomposition (IID): "Appearance" and "Lighting".
+Each model produces Albedo maps, derived from InteriorVerse and Hypersim annotations, respectively.
-```diff
- import diffusers
- import torch
+- The "Appearance" model also estimates Material properties: Roughness and Metallicity.
+- The "Lighting" model generates Diffuse Shading and Non-diffuse Residual.
- pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
- "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
- ).to("cuda")
+Here is the sample code saving predictions made by the "Appearance" model:
-+ pipe.vae = diffusers.AutoencoderTiny.from_pretrained(
-+ "madebyollin/taesd", torch_dtype=torch.float16
-+ ).cuda()
+```python
+import diffusers
+import torch
- image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
- depth = pipe(image)
+pipe = diffusers.MarigoldIntrinsicsPipeline.from_pretrained(
+ "prs-eth/marigold-iid-appearance-v1-1", variant="fp16", torch_dtype=torch.float16
+).to("cuda")
+
+image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+
+intrinsics = pipe(image)
+
+vis = pipe.image_processor.visualize_intrinsics(intrinsics.prediction, pipe.target_properties)
+vis[0]["albedo"].save("einstein_albedo.png")
+vis[0]["roughness"].save("einstein_roughness.png")
+vis[0]["metallicity"].save("einstein_metallicity.png")
```
-As suggested in [Optimizations](../optimization/torch2.0#torch.compile), adding `torch.compile` may squeeze extra performance depending on the target hardware:
+Another example demonstrating the predictions made by the "Lighting" model:
-```diff
- import diffusers
- import torch
+```python
+import diffusers
+import torch
- pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
- "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
- ).to("cuda")
+pipe = diffusers.MarigoldIntrinsicsPipeline.from_pretrained(
+ "prs-eth/marigold-iid-lighting-v1-1", variant="fp16", torch_dtype=torch.float16
+).to("cuda")
-+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
- image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
- depth = pipe(image)
+intrinsics = pipe(image)
+
+vis = pipe.image_processor.visualize_intrinsics(intrinsics.prediction, pipe.target_properties)
+vis[0]["albedo"].save("einstein_albedo.png")
+vis[0]["shading"].save("einstein_shading.png")
+vis[0]["residual"].save("einstein_residual.png")
```
-## Qualitative Comparison with Depth Anything
+Both models share the same pipeline while supporting different decomposition types.
+The exact decomposition parameterization (e.g., sRGB vs. linear space) is stored in the
+`pipe.target_properties` dictionary, which is passed into the
+[`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_intrinsics`] function.
-With the above speed optimizations, Marigold delivers predictions with more details and faster than [Depth Anything](https://huggingface.co/docs/transformers/main/en/model_doc/depth_anything) with the largest checkpoint [LiheYoung/depth-anything-large-hf](https://huggingface.co/LiheYoung/depth-anything-large-hf):
+Below are some examples showcasing the predicted decomposition outputs.
+All modalities can be inspected in the
+[Intrinsic Image Decomposition](https://huggingface.co/spaces/prs-eth/marigold-iid) Space.
-
+
- Marigold LCM fp16 with Tiny AutoEncoder
+ Predicted albedo ("Appearance" model)
-
+
- Depth Anything Large
+ Predicted diffuse shading ("Lighting" model)
-## Maximizing Precision and Ensembling
+## Speeding up inference
-Marigold pipelines have a built-in ensembling mechanism combining multiple predictions from different random latents.
-This is a brute-force way of improving the precision of predictions, capitalizing on the generative nature of diffusion.
-The ensembling path is activated automatically when the `ensemble_size` argument is set greater than `1`.
-When aiming for maximum precision, it makes sense to adjust `num_inference_steps` simultaneously with `ensemble_size`.
-The recommended values vary across checkpoints but primarily depend on the scheduler type.
-The effect of ensembling is particularly well-seen with surface normals:
+The above quick start snippets are already optimized for quality and speed, loading the checkpoint, utilizing the
+`fp16` variant of weights and computation, and performing the default number (4) of denoising diffusion steps.
+The first step to accelerate inference, at the expense of prediction quality, is to reduce the denoising diffusion
+steps to the minimum:
-```python
-import diffusers
+```diff
+ import diffusers
+ import torch
-model_path = "prs-eth/marigold-normals-v1-0"
+ pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
+ "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
+ ).to("cuda")
-model_paper_kwargs = {
- diffusers.schedulers.DDIMScheduler: {
- "num_inference_steps": 10,
- "ensemble_size": 10,
- },
- diffusers.schedulers.LCMScheduler: {
- "num_inference_steps": 4,
- "ensemble_size": 5,
- },
-}
+ image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+
+- depth = pipe(image)
++ depth = pipe(image, num_inference_steps=1)
+```
-image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+With this change, the `pipe` call completes in 280ms on RTX 3090 GPU.
+Internally, the input image is first encoded using the Stable Diffusion VAE encoder, followed by a single denoising
+step performed by the U-Net.
+Finally, the prediction latent is decoded with the VAE decoder into pixel space.
+In this setup, two out of three module calls are dedicated to converting between the pixel and latent spaces of the LDM.
+Since Marigold's latent space is compatible with Stable Diffusion 2.0, inference can be accelerated by more than 3x,
+reducing the call time to 85ms on an RTX 3090, by using a [lightweight replacement of the SD VAE](../api/models/autoencoder_tiny).
+Note that using a lightweight VAE may slightly reduce the visual quality of the predictions.
-pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(model_path).to("cuda")
-pipe_kwargs = model_paper_kwargs[type(pipe.scheduler)]
+```diff
+ import diffusers
+ import torch
-depth = pipe(image, **pipe_kwargs)
+ pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
+ "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
+ ).to("cuda")
-vis = pipe.image_processor.visualize_normals(depth.prediction)
-vis[0].save("einstein_normals.png")
++ pipe.vae = diffusers.AutoencoderTiny.from_pretrained(
++ "madebyollin/taesd", torch_dtype=torch.float16
++ ).cuda()
+
+ image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+
+ depth = pipe(image, num_inference_steps=1)
```
-
-
-
-
- Surface normals, no ensembling
-
-
-
-
-
- Surface normals, with ensembling
-
-
-
+So far, we have optimized the number of diffusion steps and model components. Self-attention operations account for a
+significant portion of computations.
+Speeding them up can be achieved by using a more efficient attention processor:
-As can be seen, all areas with fine-grained structurers, such as hair, got more conservative and on average more correct predictions.
-Such a result is more suitable for precision-sensitive downstream tasks, such as 3D reconstruction.
+```diff
+ import diffusers
+ import torch
++ from diffusers.models.attention_processor import AttnProcessor2_0
-## Quantitative Evaluation
+ pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
+ "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
+ ).to("cuda")
-To evaluate Marigold quantitatively in standard leaderboards and benchmarks (such as NYU, KITTI, and other datasets), follow the evaluation protocol outlined in the paper: load the full precision fp32 model and use appropriate values for `num_inference_steps` and `ensemble_size`.
-Optionally seed randomness to ensure reproducibility. Maximizing `batch_size` will deliver maximum device utilization.
++ pipe.vae.set_attn_processor(AttnProcessor2_0())
++ pipe.unet.set_attn_processor(AttnProcessor2_0())
-```python
-import diffusers
-import torch
+ image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
-device = "cuda"
-seed = 2024
-model_path = "prs-eth/marigold-v1-0"
-
-model_paper_kwargs = {
- diffusers.schedulers.DDIMScheduler: {
- "num_inference_steps": 50,
- "ensemble_size": 10,
- },
- diffusers.schedulers.LCMScheduler: {
- "num_inference_steps": 4,
- "ensemble_size": 10,
- },
-}
+ depth = pipe(image, num_inference_steps=1)
+```
-image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+Finally, as suggested in [Optimizations](../optimization/torch2.0#torch.compile), enabling `torch.compile` can further enhance performance depending on
+the target hardware.
+However, compilation incurs a significant overhead during the first pipeline invocation, making it beneficial only when
+the same pipeline instance is called repeatedly, such as within a loop.
-generator = torch.Generator(device=device).manual_seed(seed)
-pipe = diffusers.MarigoldDepthPipeline.from_pretrained(model_path).to(device)
-pipe_kwargs = model_paper_kwargs[type(pipe.scheduler)]
+```diff
+ import diffusers
+ import torch
+ from diffusers.models.attention_processor import AttnProcessor2_0
+
+ pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
+ "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
+ ).to("cuda")
-depth = pipe(image, generator=generator, **pipe_kwargs)
+ pipe.vae.set_attn_processor(AttnProcessor2_0())
+ pipe.unet.set_attn_processor(AttnProcessor2_0())
-# evaluate metrics
++ pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
++ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+ image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+
+ depth = pipe(image, num_inference_steps=1)
```
-## Using Predictive Uncertainty
+## Maximizing Precision and Ensembling
-The ensembling mechanism built into Marigold pipelines combines multiple predictions obtained from different random latents.
-As a side effect, it can be used to quantify epistemic (model) uncertainty; simply specify `ensemble_size` greater than 1 and set `output_uncertainty=True`.
-The resulting uncertainty will be available in the `uncertainty` field of the output.
-It can be visualized as follows:
+Marigold pipelines have a built-in ensembling mechanism combining multiple predictions from different random latents.
+This is a brute-force way of improving the precision of predictions, capitalizing on the generative nature of diffusion.
+The ensembling path is activated automatically when the `ensemble_size` argument is set greater or equal than `3`.
+When aiming for maximum precision, it makes sense to adjust `num_inference_steps` simultaneously with `ensemble_size`.
+The recommended values vary across checkpoints but primarily depend on the scheduler type.
+The effect of ensembling is particularly well-seen with surface normals:
-```python
-import diffusers
-import torch
+```diff
+ import diffusers
-pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
- "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
-).to("cuda")
+ pipe = diffusers.MarigoldNormalsPipeline.from_pretrained("prs-eth/marigold-normals-v1-1").to("cuda")
-image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
-depth = pipe(
- image,
- ensemble_size=10, # any number greater than 1; higher values yield higher precision
- output_uncertainty=True,
-)
+ image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
-uncertainty = pipe.image_processor.visualize_uncertainty(depth.uncertainty)
-uncertainty[0].save("einstein_depth_uncertainty.png")
+- depth = pipe(image)
++ depth = pipe(image, num_inference_steps=10, ensemble_size=5)
+
+ vis = pipe.image_processor.visualize_normals(depth.prediction)
+ vis[0].save("einstein_normals.png")
```
-
+
- Depth uncertainty
+ Surface normals, no ensembling
-
+
- Surface normals uncertainty
+ Surface normals, with ensembling
-The interpretation of uncertainty is easy: higher values (white) correspond to pixels, where the model struggles to make consistent predictions.
-Evidently, the depth model is the least confident around edges with discontinuity, where the object depth changes drastically.
-The surface normals model is the least confident in fine-grained structures, such as hair, and dark areas, such as the collar.
+As can be seen, all areas with fine-grained structurers, such as hair, got more conservative and on average more
+correct predictions.
+Such a result is more suitable for precision-sensitive downstream tasks, such as 3D reconstruction.
## Frame-by-frame Video Processing with Temporal Consistency
-Due to Marigold's generative nature, each prediction is unique and defined by the random noise sampled for the latent initialization.
-This becomes an obvious drawback compared to traditional end-to-end dense regression networks, as exemplified in the following videos:
+Due to Marigold's generative nature, each prediction is unique and defined by the random noise sampled for the latent
+initialization.
+This becomes an obvious drawback compared to traditional end-to-end dense regression networks, as exemplified in the
+following videos:
@@ -336,26 +373,32 @@ This becomes an obvious drawback compared to traditional end-to-end dense regres
-To address this issue, it is possible to pass `latents` argument to the pipelines, which defines the starting point of diffusion.
-Empirically, we found that a convex combination of the very same starting point noise latent and the latent corresponding to the previous frame prediction give sufficiently smooth results, as implemented in the snippet below:
+To address this issue, it is possible to pass `latents` argument to the pipelines, which defines the starting point of
+diffusion.
+Empirically, we found that a convex combination of the very same starting point noise latent and the latent
+corresponding to the previous frame prediction give sufficiently smooth results, as implemented in the snippet below:
```python
import imageio
-from PIL import Image
-from tqdm import tqdm
import diffusers
import torch
+from diffusers.models.attention_processor import AttnProcessor2_0
+from PIL import Image
+from tqdm import tqdm
device = "cuda"
-path_in = "obama.mp4"
+path_in = "https://huggingface.co/spaces/prs-eth/marigold-lcm/resolve/c7adb5427947d2680944f898cd91d386bf0d4924/files/video/obama.mp4"
path_out = "obama_depth.gif"
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
- "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
+ "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
).to(device)
pipe.vae = diffusers.AutoencoderTiny.from_pretrained(
"madebyollin/taesd", torch_dtype=torch.float16
).to(device)
+pipe.unet.set_attn_processor(AttnProcessor2_0())
+pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
+pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.set_progress_bar_config(disable=True)
with imageio.get_reader(path_in) as reader:
@@ -373,7 +416,11 @@ with imageio.get_reader(path_in) as reader:
latents = 0.9 * latents + 0.1 * last_frame_latent
depth = pipe(
- frame, match_input_resolution=False, latents=latents, output_latent=True
+ frame,
+ num_inference_steps=1,
+ match_input_resolution=False,
+ latents=latents,
+ output_latent=True,
)
last_frame_latent = depth.latent
out.append(pipe.image_processor.visualize_depth(depth.prediction)[0])
@@ -382,7 +429,8 @@ with imageio.get_reader(path_in) as reader:
```
Here, the diffusion process starts from the given computed latent.
-The pipeline sets `output_latent=True` to access `out.latent` and computes its contribution to the next frame's latent initialization.
+The pipeline sets `output_latent=True` to access `out.latent` and computes its contribution to the next frame's latent
+initialization.
The result is much more stable now:
@@ -414,7 +462,7 @@ image = diffusers.utils.load_image(
)
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
- "prs-eth/marigold-depth-lcm-v1-0", torch_dtype=torch.float16, variant="fp16"
+ "prs-eth/marigold-depth-v1-1", torch_dtype=torch.float16, variant="fp16"
).to(device)
depth_image = pipe(image, generator=generator).prediction
@@ -463,4 +511,95 @@ controlnet_out[0].save("motorcycle_controlnet_out.png")
-Hopefully, you will find Marigold useful for solving your downstream tasks, be it a part of a more broad generative workflow, or a perception task, such as 3D reconstruction.
+## Quantitative Evaluation
+
+To evaluate Marigold quantitatively in standard leaderboards and benchmarks (such as NYU, KITTI, and other datasets),
+follow the evaluation protocol outlined in the paper: load the full precision fp32 model and use appropriate values
+for `num_inference_steps` and `ensemble_size`.
+Optionally seed randomness to ensure reproducibility.
+Maximizing `batch_size` will deliver maximum device utilization.
+
+```python
+import diffusers
+import torch
+
+device = "cuda"
+seed = 2024
+
+generator = torch.Generator(device=device).manual_seed(seed)
+pipe = diffusers.MarigoldDepthPipeline.from_pretrained("prs-eth/marigold-depth-v1-1").to(device)
+
+image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+
+depth = pipe(
+ image,
+ num_inference_steps=4, # set according to the evaluation protocol from the paper
+ ensemble_size=10, # set according to the evaluation protocol from the paper
+ generator=generator,
+)
+
+# evaluate metrics
+```
+
+## Using Predictive Uncertainty
+
+The ensembling mechanism built into Marigold pipelines combines multiple predictions obtained from different random
+latents.
+As a side effect, it can be used to quantify epistemic (model) uncertainty; simply specify `ensemble_size` greater
+or equal than 3 and set `output_uncertainty=True`.
+The resulting uncertainty will be available in the `uncertainty` field of the output.
+It can be visualized as follows:
+
+```python
+import diffusers
+import torch
+
+pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
+ "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
+).to("cuda")
+
+image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+
+depth = pipe(
+ image,
+ ensemble_size=10, # any number >= 3
+ output_uncertainty=True,
+)
+
+uncertainty = pipe.image_processor.visualize_uncertainty(depth.uncertainty)
+uncertainty[0].save("einstein_depth_uncertainty.png")
+```
+
+
+
+
+
+ Depth uncertainty
+
+
+
+
+
+ Surface normals uncertainty
+
+
+
+
+
+ Albedo uncertainty
+
+
+
+
+The interpretation of uncertainty is easy: higher values (white) correspond to pixels, where the model struggles to
+make consistent predictions.
+- The depth model exhibits the most uncertainty around discontinuities, where object depth changes abruptly.
+- The surface normals model is least confident in fine-grained structures like hair and in dark regions such as the
+collar area.
+- Albedo uncertainty is represented as an RGB image, as it captures uncertainty independently for each color channel,
+unlike depth and surface normals. It is also higher in shaded regions and at discontinuities.
+
+## Conclusion
+
+We hope Marigold proves valuable for your downstream tasks, whether as part of a broader generative workflow or for
+perception-based applications like 3D reconstruction.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md
new file mode 100644
index 000000000000..40a9e81bcd52
--- /dev/null
+++ b/docs/source/en/using-diffusers/omnigen.md
@@ -0,0 +1,317 @@
+
+# OmniGen
+
+OmniGen is an image generation model. Unlike existing text-to-image models, OmniGen is a single model designed to handle a variety of tasks (e.g., text-to-image, image editing, controllable generation). It has the following features:
+- Minimalist model architecture, consisting of only a VAE and a transformer module, for joint modeling of text and images.
+- Support for multimodal inputs. It can process any text-image mixed data as instructions for image generation, rather than relying solely on text.
+
+For more information, please refer to the [paper](https://arxiv.org/pdf/2409.11340).
+This guide will walk you through using OmniGen for various tasks and use cases.
+
+## Load model checkpoints
+
+Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
+
+```python
+import torch
+from diffusers import OmniGenPipeline
+
+pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
+```
+
+## Text-to-image
+
+For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
+You can try setting the `height` and `width` parameters to generate images with different size.
+
+```python
+import torch
+from diffusers import OmniGenPipeline
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
+image = pipe(
+ prompt=prompt,
+ height=1024,
+ width=1024,
+ guidance_scale=3,
+ generator=torch.Generator(device="cpu").manual_seed(111),
+).images[0]
+image.save("output.png")
+```
+
+
+
+
+
+## Image edit
+
+OmniGen supports multimodal inputs.
+When the input includes an image, you need to add a placeholder ` <|image_1|>` in the text prompt to represent the image.
+It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
+
+```python
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt=" <|image_1|> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(222)
+).images[0]
+image.save("output.png")
+```
+
+
+
+
+
original image
+
+
+
+
edited image
+
+
+
+OmniGen has some interesting features, such as visual reasoning, as shown in the example below.
+
+```python
+prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <|image_1|>"
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(0)
+).images[0]
+image.save("output.png")
+```
+
+
+
+
+
+## Controllable generation
+
+OmniGen can handle several classic computer vision tasks. As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images.
+
+```python
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="Detect the skeleton of human in this image: <|image_1|>"
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
+image1 = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(333)
+).images[0]
+image1.save("image1.png")
+
+prompt="Generate a new photo using the following picture and text as conditions: <|image_1|>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png")]
+image2 = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(333)
+).images[0]
+image2.save("image2.png")
+```
+
+
+
+
+
original image
+
+
+
+
detected skeleton
+
+
+
+
skeleton to image
+
+
+
+
+OmniGen can also directly use relevant information from input images to generate new images.
+
+```python
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="Following the pose of this image <|image_1|>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(0)
+).images[0]
+image.save("output.png")
+```
+
+
+
+
+
generated image
+
+
+
+## ID and object preserving
+
+OmniGen can generate multiple images based on the people and objects in the input image and supports inputting multiple images simultaneously.
+Additionally, OmniGen can extract desired objects from an image containing multiple objects based on instructions.
+
+```python
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>. The woman is the woman on the left of <|image_2|>"
+input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png")
+input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.png")
+input_images=[input_image_1, input_image_2]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ height=1024,
+ width=1024,
+ guidance_scale=2.5,
+ img_guidance_scale=1.6,
+ generator=torch.Generator(device="cpu").manual_seed(666)
+).images[0]
+image.save("output.png")
+```
+
+
+
+
+
input_image_1
+
+
+
+
input_image_2
+
+
+
+
generated image
+
+
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>."
+input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg")
+input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg")
+input_images=[input_image_1, input_image_2]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ height=1024,
+ width=1024,
+ guidance_scale=2.5,
+ img_guidance_scale=1.6,
+ generator=torch.Generator(device="cpu").manual_seed(666)
+).images[0]
+image.save("output.png")
+```
+
+
+
+
+
person image
+
+
+
+
clothe image
+
+
+
+
generated image
+
+
+
+## Optimization when using multiple images
+
+For text-to-image task, OmniGen requires minimal memory and time costs (9GB memory and 31s for a 1024x1024 image on A800 GPU).
+However, when using input images, the computational cost increases.
+
+Here are some guidelines to help you reduce computational costs when using multiple images. The experiments are conducted on an A800 GPU with two input images.
+
+Like other pipelines, you can reduce memory usage by offloading the model: `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload() `.
+In OmniGen, you can also decrease computational overhead by reducing the `max_input_image_size`.
+The memory consumption for different image sizes is shown in the table below:
+
+| Method | Memory Usage |
+|---------------------------|--------------|
+| max_input_image_size=1024 | 40GB |
+| max_input_image_size=512 | 17GB |
+| max_input_image_size=256 | 14GB |
+
diff --git a/docs/source/en/using-diffusers/other-formats.md b/docs/source/en/using-diffusers/other-formats.md
index 24ac9ced84ce..e662e3940a38 100644
--- a/docs/source/en/using-diffusers/other-formats.md
+++ b/docs/source/en/using-diffusers/other-formats.md
@@ -240,6 +240,46 @@ Benefits of using a single-file layout include:
1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout.
2. Easier to manage (download and share) a single file.
+### DDUF
+
+> [!WARNING]
+> DDUF is an experimental file format and APIs related to it can change in the future.
+
+DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between Diffusers multi-folder format and the widely popular single-file format.
+
+Learn more details about DDUF on the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf).
+
+Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`].
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16
+).to("cuda")
+image = pipe(
+ "photo a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5
+).images[0]
+image.save("cat.png")
+```
+
+To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations.
+
+```py
+from huggingface_hub import export_folder_as_dduf
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+
+save_folder = "flux-dev"
+pipe.save_pretrained("flux-dev")
+export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)
+
+> [!TIP]
+> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure.
+
## Convert layout and files
Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem.
diff --git a/docs/source/en/using-diffusers/text-img2vid.md b/docs/source/en/using-diffusers/text-img2vid.md
index 8dcc73a3c81c..92e740bb579d 100644
--- a/docs/source/en/using-diffusers/text-img2vid.md
+++ b/docs/source/en/using-diffusers/text-img2vid.md
@@ -1,4 +1,4 @@
-
-# Text or image-to-video
+# Video generation
-Driven by the success of text-to-image diffusion models, generative video models are able to generate short clips of video from a text prompt or an initial image. These models extend a pretrained diffusion model to generate videos by adding some type of temporal and/or spatial convolution layer to the architecture. A mixed dataset of images and videos are used to train the model which learns to output a series of video frames based on the text or image conditioning.
+Video generation models include a temporal dimension to bring images, or frames, together to create a video. These models are trained on large-scale datasets of high-quality text-video pairs to learn how to combine the modalities to ensure the generated video is coherent and realistic.
-This guide will show you how to generate videos, how to configure video model parameters, and how to control video generation.
+[Explore](https://huggingface.co/models?other=video-generation) some of the more popular open-source video generation models available from Diffusers below.
-## Popular models
+
+
-> [!TIP]
-> Discover other cool and trending video generation models on the Hub [here](https://huggingface.co/models?pipeline_tag=text-to-video&sort=trending)!
-
-[Stable Video Diffusions (SVD)](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid), [I2VGen-XL](https://huggingface.co/ali-vilab/i2vgen-xl/), [AnimateDiff](https://huggingface.co/guoyww/animatediff), and [ModelScopeT2V](https://huggingface.co/ali-vilab/text-to-video-ms-1.7b) are popular models used for video diffusion. Each model is distinct. For example, AnimateDiff inserts a motion modeling module into a frozen text-to-image model to generate personalized animated images, whereas SVD is entirely pretrained from scratch with a three-stage training process to generate short high-quality videos.
+[CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) uses a 3D causal Variational Autoencoder (VAE) to compress videos along the spatial and temporal dimensions, and it includes a stack of expert transformer blocks with a 3D full attention mechanism to better capture visual, semantic, and motion information in the data.
-[CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) is another popular video generation model. The model is a multidimensional transformer that integrates text, time, and space. It employs full attention in the attention module and includes an expert block at the layer level to spatially align text and video.
+The CogVideoX family also includes models capable of generating videos from images and videos in addition to text. The image-to-video models are indicated by **I2V** in the checkpoint name, and they should be used with the [`CogVideoXImageToVideoPipeline`]. The regular checkpoints support video-to-video through the [`CogVideoXVideoToVideoPipeline`].
-### CogVideoX
-
-[CogVideoX](../api/pipelines/cogvideox) uses a 3D Variational Autoencoder (VAE) to compress videos along the spatial and temporal dimensions.
-
-Begin by loading the [`CogVideoXPipeline`] and passing an initial text or image to generate a video.
-
-
-CogVideoX is available for image-to-video and text-to-video. [THUDM/CogVideoX-5b-I2V](https://huggingface.co/THUDM/CogVideoX-5b-I2V) uses the [`CogVideoXImageToVideoPipeline`] for image-to-video. [THUDM/CogVideoX-5b](https://huggingface.co/THUDM/CogVideoX-5b) and [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b) are available for text-to-video with the [`CogVideoXPipeline`].
-
-
+The example below demonstrates how to generate a video from an image and text prompt with [THUDM/CogVideoX-5b-I2V](https://huggingface.co/THUDM/CogVideoX-5b-I2V).
```py
import torch
@@ -42,12 +31,13 @@ from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
prompt = "A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion."
-image = load_image(image="cogvideox_rocket.png")
+image = load_image(image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_rocket.png")
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
"THUDM/CogVideoX-5b-I2V",
torch_dtype=torch.bfloat16
)
-
+
+# reduce memory requirements
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
@@ -60,7 +50,6 @@ video = pipe(
guidance_scale=6,
generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]
-
export_to_video(video, "output.mp4", fps=8)
```
@@ -75,90 +64,141 @@ export_to_video(video, "output.mp4", fps=8)
-
-### Stable Video Diffusion
+
+
-[SVD](../api/pipelines/svd) is based on the Stable Diffusion 2.1 model and it is trained on images, then low-resolution videos, and finally a smaller dataset of high-resolution videos. This model generates a short 2-4 second video from an initial image. You can learn more details about model, like micro-conditioning, in the [Stable Video Diffusion](../using-diffusers/svd) guide.
+> [!TIP]
+> HunyuanVideo is a 13B parameter model and requires a lot of memory. Refer to the HunyuanVideo [Quantization](../api/pipelines/hunyuan_video#quantization) guide to learn how to quantize the model. CogVideoX and LTX-Video are more lightweight options that can still generate high-quality videos.
-Begin by loading the [`StableVideoDiffusionPipeline`] and passing an initial image to generate a video from.
+[HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo) features a dual-stream to single-stream diffusion transformer (DiT) for learning video and text tokens separately, and then subsequently concatenating the video and text tokens to combine their information. A single multimodal large language model (MLLM) serves as the text encoder, and videos are also spatio-temporally compressed with a 3D causal VAE.
```py
import torch
-from diffusers import StableVideoDiffusionPipeline
-from diffusers.utils import load_image, export_to_video
+from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+from diffusers.utils import export_to_video
-pipeline = StableVideoDiffusionPipeline.from_pretrained(
- "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
+transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ "hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
+)
+pipe = HunyuanVideoPipeline.from_pretrained(
+ "hunyuanvideo-community/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
)
-pipeline.enable_model_cpu_offload()
-image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
-image = image.resize((1024, 576))
+# reduce memory requirements
+pipe.vae.enable_tiling()
+pipe.to("cuda")
-generator = torch.manual_seed(42)
-frames = pipeline(image, decode_chunk_size=8, generator=generator).frames[0]
-export_to_video(frames, "generated.mp4", fps=7)
+video = pipe(
+ prompt="A cat walks on the grass, realistic",
+ height=320,
+ width=512,
+ num_frames=61,
+ num_inference_steps=30,
+).frames[0]
+export_to_video(video, "output.mp4", fps=15)
```
-
-
-
-
initial image
-
-
-
-
generated video
-
+
+
-### I2VGen-XL
-
-[I2VGen-XL](../api/pipelines/i2vgenxl) is a diffusion model that can generate higher resolution videos than SVD and it is also capable of accepting text prompts in addition to images. The model is trained with two hierarchical encoders (detail and global encoder) to better capture low and high-level details in images. These learned details are used to train a video diffusion model which refines the video resolution and details in the generated video.
+
+
-You can use I2VGen-XL by loading the [`I2VGenXLPipeline`], and passing a text and image prompt to generate a video.
+[LTX-Video (LTXV)](https://huggingface.co/Lightricks/LTX-Video) is a diffusion transformer (DiT) with a focus on speed. It generates 768x512 resolution videos at 24 frames per second (fps), enabling near real-time generation of high-quality videos. LTXV is relatively lightweight compared to other modern video generation models, making it possible to run on consumer GPUs.
```py
import torch
-from diffusers import I2VGenXLPipeline
-from diffusers.utils import export_to_gif, load_image
-
-pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
-pipeline.enable_model_cpu_offload()
-
-image_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
-image = load_image(image_url).convert("RGB")
+from diffusers import LTXPipeline
+from diffusers.utils import export_to_video
-prompt = "Papers were floating in the air on a table in the library"
-negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
-generator = torch.manual_seed(8888)
+pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
-frames = pipeline(
+prompt = "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage."
+video = pipe(
prompt=prompt,
- image=image,
+ width=704,
+ height=480,
+ num_frames=161,
num_inference_steps=50,
- negative_prompt=negative_prompt,
- guidance_scale=9.0,
- generator=generator
).frames[0]
-export_to_gif(frames, "i2v.gif")
+export_to_video(video, "output.mp4", fps=24)
+```
+
+
+
+
+
+
+
+
+> [!TIP]
+> Mochi-1 is a 10B parameter model and requires a lot of memory. Refer to the Mochi [Quantization](../api/pipelines/mochi#quantization) guide to learn how to quantize the model. CogVideoX and LTX-Video are more lightweight options that can still generate high-quality videos.
+
+[Mochi-1](https://huggingface.co/genmo/mochi-1-preview) introduces the Asymmetric Diffusion Transformer (AsymmDiT) and Asymmetric Variational Autoencoder (AsymmVAE) to reduces memory requirements. AsymmVAE causally compresses videos 128x to improve memory efficiency, and AsymmDiT jointly attends to the compressed video tokens and user text tokens. This model is noted for generating videos with high-quality motion dynamics and strong prompt adherence.
+
+```py
+import torch
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+
+pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
+
+# reduce memory requirements
+pipe.enable_model_cpu_offload()
+pipe.enable_vae_tiling()
+
+prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
+video = pipe(prompt, num_frames=84).frames[0]
+export_to_video(video, "output.mp4", fps=30)
+```
+
+
+
+
+
+
+
+
+[StableVideoDiffusion (SVD)](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) is based on the Stable Diffusion 2.1 model and it is trained on images, then low-resolution videos, and finally a smaller dataset of high-resolution videos. This model generates a short 2-4 second video from an initial image.
+
+```py
+import torch
+from diffusers import StableVideoDiffusionPipeline
+from diffusers.utils import load_image, export_to_video
+
+pipeline = StableVideoDiffusionPipeline.from_pretrained(
+ "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
+)
+
+# reduce memory requirements
+pipeline.enable_model_cpu_offload()
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
+image = image.resize((1024, 576))
+
+generator = torch.manual_seed(42)
+frames = pipeline(image, decode_chunk_size=8, generator=generator).frames[0]
+export_to_video(frames, "generated.mp4", fps=7)
```
-
+
initial image
-
+
generated video
-### AnimateDiff
+
+
-[AnimateDiff](../api/pipelines/animatediff) is an adapter model that inserts a motion module into a pretrained diffusion model to animate an image. The adapter is trained on video clips to learn motion which is used to condition the generation process to create a video. It is faster and easier to only train the adapter and it can be loaded into most diffusion models, effectively turning them into "video models".
+[AnimateDiff](https://huggingface.co/guoyww/animatediff) is an adapter model that inserts a motion module into a pretrained diffusion model to animate an image. The adapter is trained on video clips to learn motion which is used to condition the generation process to create a video. It is faster and easier to only train the adapter and it can be loaded into most diffusion models, effectively turning them into “video models”.
-Start by loading a [`MotionAdapter`].
+Load a `MotionAdapter` and pass it to the [`AnimateDiffPipeline`].
```py
import torch
@@ -166,11 +206,6 @@ from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
-```
-
-Then load a finetuned Stable Diffusion model with the [`AnimateDiffPipeline`].
-
-```py
pipeline = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16)
scheduler = DDIMScheduler.from_pretrained(
"emilianJR/epiCRealism",
@@ -181,13 +216,11 @@ scheduler = DDIMScheduler.from_pretrained(
steps_offset=1,
)
pipeline.scheduler = scheduler
+
+# reduce memory requirements
pipeline.enable_vae_slicing()
pipeline.enable_model_cpu_offload()
-```
-Create a prompt and generate the video.
-
-```py
output = pipeline(
prompt="A space rocket with trails of smoke behind it launching into space from the desert, 4k, high resolution",
negative_prompt="bad quality, worse quality, low resolution",
@@ -201,38 +234,11 @@ export_to_gif(frames, "animation.gif")
```
-
+
-### ModelscopeT2V
-
-[ModelscopeT2V](../api/pipelines/text_to_video) adds spatial and temporal convolutions and attention to a UNet, and it is trained on image-text and video-text datasets to enhance what it learns during training. The model takes a prompt, encodes it and creates text embeddings which are denoised by the UNet, and then decoded by a VQGAN into a video.
-
-
-
-ModelScopeT2V generates watermarked videos due to the datasets it was trained on. To use a watermark-free model, try the [cerspense/zeroscope_v2_76w](https://huggingface.co/cerspense/zeroscope_v2_576w) model with the [`TextToVideoSDPipeline`] first, and then upscale it's output with the [cerspense/zeroscope_v2_XL](https://huggingface.co/cerspense/zeroscope_v2_XL) checkpoint using the [`VideoToVideoSDPipeline`].
-
-
-
-Load a ModelScopeT2V checkpoint into the [`DiffusionPipeline`] along with a prompt to generate a video.
-
-```py
-import torch
-from diffusers import DiffusionPipeline
-from diffusers.utils import export_to_video
-
-pipeline = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
-pipeline.enable_model_cpu_offload()
-pipeline.enable_vae_slicing()
-
-prompt = "Confident teddy bear surfer rides the wave in the tropics"
-video_frames = pipeline(prompt).frames[0]
-export_to_video(video_frames, "modelscopet2v.mp4", fps=10)
-```
-
-
-
-
+
+
## Configure model parameters
@@ -548,3 +554,9 @@ If memory is not an issue and you want to optimize for speed, try wrapping the U
+ pipeline.to("cuda")
+ pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) to learn more about supported quantization backends (bitsandbytes, torchao, gguf) and selecting a quantization backend that supports your use case.
diff --git a/docs/source/en/using-diffusers/weighted_prompts.md b/docs/source/en/using-diffusers/weighted_prompts.md
index 712eebc9450c..f310d8f49550 100644
--- a/docs/source/en/using-diffusers/weighted_prompts.md
+++ b/docs/source/en/using-diffusers/weighted_prompts.md
@@ -215,7 +215,7 @@ image
Prompt weighting provides a way to emphasize or de-emphasize certain parts of a prompt, allowing for more control over the generated image. A prompt can include several concepts, which gets turned into contextualized text embeddings. The embeddings are used by the model to condition its cross-attention layers to generate an image (read the Stable Diffusion [blog post](https://huggingface.co/blog/stable_diffusion) to learn more about how it works).
-Prompt weighting works by increasing or decreasing the scale of the text embedding vector that corresponds to its concept in the prompt because you may not necessarily want the model to focus on all concepts equally. The easiest way to prepare the prompt-weighted embeddings is to use [Compel](https://github.com/damian0815/compel), a text prompt-weighting and blending library. Once you have the prompt-weighted embeddings, you can pass them to any pipeline that has a [`prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) (and optionally [`negative_prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.negative_prompt_embeds)) parameter, such as [`StableDiffusionPipeline`], [`StableDiffusionControlNetPipeline`], and [`StableDiffusionXLPipeline`].
+Prompt weighting works by increasing or decreasing the scale of the text embedding vector that corresponds to its concept in the prompt because you may not necessarily want the model to focus on all concepts equally. The easiest way to prepare the prompt embeddings is to use [Stable Diffusion Long Prompt Weighted Embedding](https://github.com/xhinker/sd_embed) (sd_embed). Once you have the prompt-weighted embeddings, you can pass them to any pipeline that has a [prompt_embeds](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) (and optionally [negative_prompt_embeds](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.negative_prompt_embeds)) parameter, such as [`StableDiffusionPipeline`], [`StableDiffusionControlNetPipeline`], and [`StableDiffusionXLPipeline`].
@@ -223,136 +223,99 @@ If your favorite pipeline doesn't have a `prompt_embeds` parameter, please open
-This guide will show you how to weight and blend your prompts with Compel in 🤗 Diffusers.
+This guide will show you how to weight your prompts with sd_embed.
-Before you begin, make sure you have the latest version of Compel installed:
+Before you begin, make sure you have the latest version of sd_embed installed:
-```py
-# uncomment to install in Colab
-#!pip install compel --upgrade
+```bash
+pip install git+https://github.com/xhinker/sd_embed.git@main
```
-For this guide, let's generate an image with the prompt `"a red cat playing with a ball"` using the [`StableDiffusionPipeline`]:
+For this example, let's use [`StableDiffusionXLPipeline`].
```py
-from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
+from diffusers import StableDiffusionXLPipeline, UniPCMultistepScheduler
import torch
-pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_safetensors=True)
+pipe = StableDiffusionXLPipeline.from_pretrained("Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
-
-prompt = "a red cat playing with a ball"
-
-generator = torch.Generator(device="cpu").manual_seed(33)
-
-image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
-image
-```
-
-
-
-
-
-### Weighting
-
-You'll notice there is no "ball" in the image! Let's use compel to upweight the concept of "ball" in the prompt. Create a [`Compel`](https://github.com/damian0815/compel/blob/main/doc/compel.md#compel-objects) object, and pass it a tokenizer and text encoder:
-
-```py
-from compel import Compel
-
-compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
```
-compel uses `+` or `-` to increase or decrease the weight of a word in the prompt. To increase the weight of "ball":
+To upweight or downweight a concept, surround the text with parentheses. More parentheses applies a heavier weight on the text. You can also append a numerical multiplier to the text to indicate how much you want to increase or decrease its weights by.
-
-
-`+` corresponds to the value `1.1`, `++` corresponds to `1.1^2`, and so on. Similarly, `-` corresponds to `0.9` and `--` corresponds to `0.9^2`. Feel free to experiment with adding more `+` or `-` in your prompt!
+| format | multiplier |
+|---|---|
+| `(hippo)` | increase by 1.1x |
+| `((hippo))` | increase by 1.21x |
+| `(hippo:1.5)` | increase by 1.5x |
+| `(hippo:0.5)` | decrease by 4x |
-
+Create a prompt and use a combination of parentheses and numerical multipliers to upweight various text.
```py
-prompt = "a red cat playing with a ball++"
+from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl
+
+prompt = """A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
+This imaginative creature features the distinctive, bulky body of a hippo,
+but with a texture and appearance resembling a golden-brown, crispy waffle.
+The creature might have elements like waffle squares across its skin and a syrup-like sheen.
+It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
+possibly including oversized utensils or plates in the background.
+The image should evoke a sense of playful absurdity and culinary fantasy.
+"""
+
+neg_prompt = """\
+skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
+(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
+extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
+(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
+bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
+(normal quality:2),lowres,((monochrome)),((grayscale))
+"""
```
-Pass the prompt to `compel_proc` to create the new prompt embeddings which are passed to the pipeline:
-
-```py
-prompt_embeds = compel_proc(prompt)
-generator = torch.manual_seed(33)
-
-image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
-image
-```
+Use the `get_weighted_text_embeddings_sdxl` function to generate the prompt embeddings and the negative prompt embeddings. It'll also generated the pooled and negative pooled prompt embeddings since you're using the SDXL model.
-
-
-
-
-To downweight parts of the prompt, use the `-` suffix:
-
-```py
-prompt = "a red------- cat playing with a ball"
-prompt_embeds = compel_proc(prompt)
-
-generator = torch.manual_seed(33)
-
-image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
-image
-```
-
-
-
-
-
-You can even up or downweight multiple concepts in the same prompt:
-
-```py
-prompt = "a red cat++ playing with a ball----"
-prompt_embeds = compel_proc(prompt)
-
-generator = torch.manual_seed(33)
-
-image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
-image
-```
-
-
-
-
-
-### Blending
-
-You can also create a weighted *blend* of prompts by adding `.blend()` to a list of prompts and passing it some weights. Your blend may not always produce the result you expect because it breaks some assumptions about how the text encoder functions, so just have fun and experiment with it!
+> [!TIP]
+> You can safely ignore the error message below about the token index length exceeding the models maximum sequence length. All your tokens will be used in the embedding process.
+>
+> ```
+> Token indices sequence length is longer than the specified maximum sequence length for this model
+> ```
```py
-prompt_embeds = compel_proc('("a red cat playing with a ball", "jungle").blend(0.7, 0.8)')
-generator = torch.Generator(device="cuda").manual_seed(33)
+(
+ prompt_embeds,
+ prompt_neg_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds
+) = get_weighted_text_embeddings_sdxl(
+ pipe,
+ prompt=prompt,
+ neg_prompt=neg_prompt
+)
-image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
+image = pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=prompt_neg_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ num_inference_steps=30,
+ height=1024,
+ width=1024 + 512,
+ guidance_scale=4.0,
+ generator=torch.Generator("cuda").manual_seed(2)
+).images[0]
image
```
-
+
-### Conjunction
-
-A conjunction diffuses each prompt independently and concatenates their results by their weighted sum. Add `.and()` to the end of a list of prompts to create a conjunction:
-
-```py
-prompt_embeds = compel_proc('["a red cat", "playing with a", "ball"].and()')
-generator = torch.Generator(device="cuda").manual_seed(55)
-
-image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
-image
-```
-
-
-
-
+> [!TIP]
+> Refer to the [sd_embed](https://github.com/xhinker/sd_embed) repository for additional details about long prompt weighting for FLUX.1, Stable Cascade, and Stable Diffusion 1.5.
### Textual inversion
@@ -363,35 +326,63 @@ Create a pipeline and use the [`~loaders.TextualInversionLoaderMixin.load_textua
```py
import torch
from diffusers import StableDiffusionPipeline
-from compel import Compel, DiffusersTextualInversionManager
pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16,
- use_safetensors=True, variant="fp16").to("cuda")
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+).to("cuda")
pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
```
-Compel provides a `DiffusersTextualInversionManager` class to simplify prompt weighting with textual inversion. Instantiate `DiffusersTextualInversionManager` and pass it to the `Compel` class:
+Add the `
` text to the prompt to trigger the textual inversion.
```py
-textual_inversion_manager = DiffusersTextualInversionManager(pipe)
-compel_proc = Compel(
- tokenizer=pipe.tokenizer,
- text_encoder=pipe.text_encoder,
- textual_inversion_manager=textual_inversion_manager)
+from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15
+
+prompt = """ A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
+This imaginative creature features the distinctive, bulky body of a hippo,
+but with a texture and appearance resembling a golden-brown, crispy waffle.
+The creature might have elements like waffle squares across its skin and a syrup-like sheen.
+It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
+possibly including oversized utensils or plates in the background.
+The image should evoke a sense of playful absurdity and culinary fantasy.
+"""
+
+neg_prompt = """\
+skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
+(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
+extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
+(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
+bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
+(normal quality:2),lowres,((monochrome)),((grayscale))
+"""
```
-Incorporate the concept to condition a prompt with using the `` syntax:
+Use the `get_weighted_text_embeddings_sd15` function to generate the prompt embeddings and the negative prompt embeddings.
```py
-prompt_embeds = compel_proc('("A red cat++ playing with a ball ")')
+(
+ prompt_embeds,
+ prompt_neg_embeds,
+) = get_weighted_text_embeddings_sd15(
+ pipe,
+ prompt=prompt,
+ neg_prompt=neg_prompt
+)
-image = pipe(prompt_embeds=prompt_embeds).images[0]
+image = pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=prompt_neg_embeds,
+ height=768,
+ width=896,
+ guidance_scale=4.0,
+ generator=torch.Generator("cuda").manual_seed(2)
+).images[0]
image
```
-
+
### DreamBooth
@@ -401,70 +392,44 @@ image
```py
import torch
from diffusers import DiffusionPipeline, UniPCMultistepScheduler
-from compel import Compel
pipe = DiffusionPipeline.from_pretrained("sd-dreambooth-library/dndcoverart-v1", torch_dtype=torch.float16).to("cuda")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
```
-Create a `Compel` class with a tokenizer and text encoder, and pass your prompt to it. Depending on the model you use, you'll need to incorporate the model's unique identifier into your prompt. For example, the `dndcoverart-v1` model uses the identifier `dndcoverart`:
-
-```py
-compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
-prompt_embeds = compel_proc('("magazine cover of a dndcoverart dragon, high quality, intricate details, larry elmore art style").and()')
-image = pipe(prompt_embeds=prompt_embeds).images[0]
-image
-```
-
-
-
-
-
-### Stable Diffusion XL
-
-Stable Diffusion XL (SDXL) has two tokenizers and text encoders so it's usage is a bit different. To address this, you should pass both tokenizers and encoders to the `Compel` class:
+Depending on the model you use, you'll need to incorporate the model's unique identifier into your prompt. For example, the `dndcoverart-v1` model uses the identifier `dndcoverart`:
```py
-from compel import Compel, ReturnedEmbeddingsType
-from diffusers import DiffusionPipeline
-from diffusers.utils import make_image_grid
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- variant="fp16",
- use_safetensors=True,
- torch_dtype=torch.float16
-).to("cuda")
-
-compel = Compel(
- tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2] ,
- text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
- returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
- requires_pooled=[False, True]
+from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15
+
+prompt = """dndcoverart of A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
+This imaginative creature features the distinctive, bulky body of a hippo,
+but with a texture and appearance resembling a golden-brown, crispy waffle.
+The creature might have elements like waffle squares across its skin and a syrup-like sheen.
+It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
+possibly including oversized utensils or plates in the background.
+The image should evoke a sense of playful absurdity and culinary fantasy.
+"""
+
+neg_prompt = """\
+skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
+(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
+extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
+(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
+bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
+(normal quality:2),lowres,((monochrome)),((grayscale))
+"""
+
+(
+ prompt_embeds
+ , prompt_neg_embeds
+) = get_weighted_text_embeddings_sd15(
+ pipe
+ , prompt = prompt
+ , neg_prompt = neg_prompt
)
```
-This time, let's upweight "ball" by a factor of 1.5 for the first prompt, and downweight "ball" by 0.6 for the second prompt. The [`StableDiffusionXLPipeline`] also requires [`pooled_prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLInpaintPipeline.__call__.pooled_prompt_embeds) (and optionally [`negative_pooled_prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLInpaintPipeline.__call__.negative_pooled_prompt_embeds)) so you should pass those to the pipeline along with the conditioning tensors:
-
-```py
-# apply weights
-prompt = ["a red cat playing with a (ball)1.5", "a red cat playing with a (ball)0.6"]
-conditioning, pooled = compel(prompt)
-
-# generate image
-generator = [torch.Generator().manual_seed(33) for _ in range(len(prompt))]
-images = pipeline(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, generator=generator, num_inference_steps=30).images
-make_image_grid(images, rows=1, cols=2)
-```
-
-
-
-
-
"a red cat playing with a (ball)1.5"
-
-
-
-
"a red cat playing with a (ball)0.6"
-
+
+
diff --git a/docs/source/en/using-diffusers/write_own_pipeline.md b/docs/source/en/using-diffusers/write_own_pipeline.md
index bdcd4e5d1307..283397ff3e9d 100644
--- a/docs/source/en/using-diffusers/write_own_pipeline.md
+++ b/docs/source/en/using-diffusers/write_own_pipeline.md
@@ -106,7 +106,7 @@ Let's try it out!
## Deconstruct the Stable Diffusion pipeline
-Stable Diffusion is a text-to-image *latent diffusion* model. It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder to convert the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and an encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.
+Stable Diffusion is a text-to-image *latent diffusion* model. It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder converts the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and an encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.
As you can see, this is already more complex than the DDPM pipeline which only contains a UNet model. The Stable Diffusion model has three separate pretrained models.
diff --git a/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md b/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md
index d7211d6b9471..d708dfa59dad 100644
--- a/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md
+++ b/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md
@@ -121,7 +121,7 @@ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inferen
### 이미지 결과물을 정제하기
-[base 모델 체크포인트](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)에서, StableDiffusion-XL 또한 고주파 품질을 향상시키는 이미지를 생성하기 위해 낮은 노이즈 단계 이미지를 제거하는데 특화된 [refiner 체크포인트](huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)를 포함하고 있습니다. 이 refiner 체크포인트는 이미지 품질을 향상시키기 위해 base 체크포인트를 실행한 후 "두 번째 단계" 파이프라인에 사용될 수 있습니다.
+[base 모델 체크포인트](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)에서, StableDiffusion-XL 또한 고주파 품질을 향상시키는 이미지를 생성하기 위해 낮은 노이즈 단계 이미지를 제거하는데 특화된 [refiner 체크포인트](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)를 포함하고 있습니다. 이 refiner 체크포인트는 이미지 품질을 향상시키기 위해 base 체크포인트를 실행한 후 "두 번째 단계" 파이프라인에 사용될 수 있습니다.
refiner를 사용할 때, 쉽게 사용할 수 있습니다
- 1.) base 모델과 refiner을 사용하는데, 이는 *Denoisers의 앙상블*을 위한 첫 번째 제안된 [eDiff-I](https://research.nvidia.com/labs/dir/eDiff-I/)를 사용하거나
@@ -215,7 +215,7 @@ image = refiner(
#### 2.) 노이즈가 완전히 제거된 기본 이미지에서 이미지 출력을 정제하기
-일반적인 [`StableDiffusionImg2ImgPipeline`] 방식에서, 기본 모델에서 생성된 완전히 노이즈가 제거된 이미지는 [refiner checkpoint](huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)를 사용해 더 향상시킬 수 있습니다.
+일반적인 [`StableDiffusionImg2ImgPipeline`] 방식에서, 기본 모델에서 생성된 완전히 노이즈가 제거된 이미지는 [refiner checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)를 사용해 더 향상시킬 수 있습니다.
이를 위해, 보통의 "base" text-to-image 파이프라인을 수행 후에 image-to-image 파이프라인으로써 refiner를 실행시킬 수 있습니다. base 모델의 출력을 잠재 공간에 남겨둘 수 있습니다.
diff --git a/docs/source/ko/training/controlnet.md b/docs/source/ko/training/controlnet.md
index afdd2c8e0004..ce83cab54e8b 100644
--- a/docs/source/ko/training/controlnet.md
+++ b/docs/source/ko/training/controlnet.md
@@ -66,12 +66,6 @@ from accelerate.utils import write_basic_config
write_basic_config()
```
-## 원을 채우는 데이터셋
-
-원본 데이터셋은 ControlNet [repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip)에 올라와있지만, 우리는 [여기](https://huggingface.co/datasets/fusing/fill50k)에 새롭게 다시 올려서 🤗 Datasets 과 호환가능합니다. 그래서 학습 스크립트 상에서 데이터 불러오기를 다룰 수 있습니다.
-
-우리의 학습 예시는 원래 ControlNet의 학습에 쓰였던 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)을 사용합니다. 그렇지만 ControlNet은 대응되는 어느 Stable Diffusion 모델([`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)) 혹은 [`stabilityai/stable-diffusion-2-1`](https://huggingface.co/stabilityai/stable-diffusion-2-1)의 증가를 위해 학습될 수 있습니다.
-
자체 데이터셋을 사용하기 위해서는 [학습을 위한 데이터셋 생성하기](create_dataset) 가이드를 확인하세요.
## 학습
diff --git a/docs/source/ko/training/create_dataset.md b/docs/source/ko/training/create_dataset.md
index 6987a6c9d4f0..401a73ebf237 100644
--- a/docs/source/ko/training/create_dataset.md
+++ b/docs/source/ko/training/create_dataset.md
@@ -1,7 +1,7 @@
# 학습을 위한 데이터셋 만들기
[Hub](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads) 에는 모델 교육을 위한 많은 데이터셋이 있지만,
-관심이 있거나 사용하고 싶은 데이터셋을 찾을 수 없는 경우 🤗 [Datasets](hf.co/docs/datasets) 라이브러리를 사용하여 데이터셋을 만들 수 있습니다.
+관심이 있거나 사용하고 싶은 데이터셋을 찾을 수 없는 경우 🤗 [Datasets](https://huggingface.co/docs/datasets) 라이브러리를 사용하여 데이터셋을 만들 수 있습니다.
데이터셋 구조는 모델을 학습하려는 작업에 따라 달라집니다.
가장 기본적인 데이터셋 구조는 unconditional 이미지 생성과 같은 작업을 위한 이미지 디렉토리입니다.
또 다른 데이터셋 구조는 이미지 디렉토리와 text-to-image 생성과 같은 작업에 해당하는 텍스트 캡션이 포함된 텍스트 파일일 수 있습니다.
diff --git a/docs/source/ko/training/lora.md b/docs/source/ko/training/lora.md
index 6b905951aafc..85ed1dda0b81 100644
--- a/docs/source/ko/training/lora.md
+++ b/docs/source/ko/training/lora.md
@@ -36,7 +36,7 @@ specific language governing permissions and limitations under the License.
[cloneofsimo](https://github.com/cloneofsimo)는 인기 있는 [lora](https://github.com/cloneofsimo/lora) GitHub 리포지토리에서 Stable Diffusion을 위한 LoRA 학습을 최초로 시도했습니다. 🧨 Diffusers는 [text-to-image 생성](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) 및 [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora)을 지원합니다. 이 가이드는 두 가지를 모두 수행하는 방법을 보여줍니다.
-모델을 저장하거나 커뮤니티와 공유하려면 Hugging Face 계정에 로그인하세요(아직 계정이 없는 경우 [생성](hf.co/join)하세요):
+모델을 저장하거나 커뮤니티와 공유하려면 Hugging Face 계정에 로그인하세요(아직 계정이 없는 경우 [생성](https://huggingface.co/join)하세요):
```bash
huggingface-cli login
diff --git a/docs/source/ko/tutorials/basic_training.md b/docs/source/ko/tutorials/basic_training.md
index f34507b50c9d..5b08bb39d602 100644
--- a/docs/source/ko/tutorials/basic_training.md
+++ b/docs/source/ko/tutorials/basic_training.md
@@ -76,7 +76,7 @@ huggingface-cli login
... output_dir = "ddpm-butterflies-128" # 로컬 및 HF Hub에 저장되는 모델명
... push_to_hub = True # 저장된 모델을 HF Hub에 업로드할지 여부
-... hub_private_repo = False
+... hub_private_repo = None
... overwrite_output_dir = True # 노트북을 다시 실행할 때 이전 모델에 덮어씌울지
... seed = 0
diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml
index 41d5e95a4230..6416c468a8e9 100644
--- a/docs/source/zh/_toctree.yml
+++ b/docs/source/zh/_toctree.yml
@@ -5,6 +5,8 @@
title: 快速入门
- local: stable_diffusion
title: 有效和高效的扩散
+ - local: consisid
+ title: 身份保持的文本到视频生成
- local: installation
title: 安装
title: 开始
diff --git a/docs/source/zh/consisid.md b/docs/source/zh/consisid.md
new file mode 100644
index 000000000000..2f404499fc69
--- /dev/null
+++ b/docs/source/zh/consisid.md
@@ -0,0 +1,100 @@
+
+# ConsisID
+
+[ConsisID](https://github.com/PKU-YuanGroup/ConsisID)是一种身份保持的文本到视频生成模型,其通过频率分解在生成的视频中保持面部一致性。它具有以下特点:
+
+- 基于频率分解:将人物ID特征解耦为高频和低频部分,从频域的角度分析DIT架构的特性,并且基于此特性设计合理的控制信息注入方式。
+
+- 一致性训练策略:我们提出粗到细训练策略、动态掩码损失、动态跨脸损失,进一步提高了模型的泛化能力和身份保持效果。
+
+
+- 推理无需微调:之前的方法在推理前,需要对输入id进行case-by-case微调,时间和算力开销较大,而我们的方法是tuning-free的。
+
+
+本指南将指导您使用 ConsisID 生成身份保持的视频。
+
+## Load Model Checkpoints
+模型权重可以存储在Hub上或本地的单独子文件夹中,在这种情况下,您应该使用 [`~DiffusionPipeline.from_pretrained`] 方法。
+
+
+```python
+# !pip install consisid_eva_clip insightface facexlib
+import torch
+from diffusers import ConsisIDPipeline
+from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
+from huggingface_hub import snapshot_download
+
+# Download ckpts
+snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
+
+# Load face helper model to preprocess input face image
+face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
+
+# Load consisid base model
+pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+```
+
+## Identity-Preserving Text-to-Video
+对于身份保持的文本到视频生成,需要输入文本提示和包含清晰面部(例如,最好是半身或全身)的图像。默认情况下,ConsisID 会生成 720x480 的视频以获得最佳效果。
+
+```python
+from diffusers.utils import export_to_video
+
+prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel."
+image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true"
+
+id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(face_helper_1, face_clip_model, face_helper_2, eva_transform_mean, eva_transform_std, face_main_model, "cuda", torch.bfloat16, image, is_align_face=True)
+
+video = pipe(image=image, prompt=prompt, num_inference_steps=50, guidance_scale=6.0, use_dynamic_cfg=False, id_vit_hidden=id_vit_hidden, id_cond=id_cond, kps_cond=face_kps, generator=torch.Generator("cuda").manual_seed(42))
+export_to_video(video.frames[0], "output.mp4", fps=8)
+```
+
+
+ Face Image
+ Video
+ Description
+
+
+
+ The video, in a beautifully crafted animated style, features a confident woman riding a horse through a lush forest clearing. Her expression is focused yet serene as she adjusts her wide-brimmed hat with a practiced hand. She wears a flowy bohemian dress, which moves gracefully with the rhythm of the horse, the fabric flowing fluidly in the animated motion. The dappled sunlight filters through the trees, casting soft, painterly patterns on the forest floor. Her posture is poised, showing both control and elegance as she guides the horse with ease. The animation's gentle, fluid style adds a dreamlike quality to the scene, with the woman’s calm demeanor and the peaceful surroundings evoking a sense of freedom and harmony.
+
+
+
+
+ The video, in a captivating animated style, shows a woman standing in the center of a snowy forest, her eyes narrowed in concentration as she extends her hand forward. She is dressed in a deep blue cloak, her breath visible in the cold air, which is rendered with soft, ethereal strokes. A faint smile plays on her lips as she summons a wisp of ice magic, watching with focus as the surrounding trees and ground begin to shimmer and freeze, covered in delicate ice crystals. The animation’s fluid motion brings the magic to life, with the frost spreading outward in intricate, sparkling patterns. The environment is painted with soft, watercolor-like hues, enhancing the magical, dreamlike atmosphere. The overall mood is serene yet powerful, with the quiet winter air amplifying the delicate beauty of the frozen scene.
+
+
+
+
+ The animation features a whimsical portrait of a balloon seller standing in a gentle breeze, captured with soft, hazy brushstrokes that evoke the feel of a serene spring day. His face is framed by a gentle smile, his eyes squinting slightly against the sun, while a few wisps of hair flutter in the wind. He is dressed in a light, pastel-colored shirt, and the balloons around him sway with the wind, adding a sense of playfulness to the scene. The background blurs softly, with hints of a vibrant market or park, enhancing the light-hearted, yet tender mood of the moment.
+
+
+
+
+ The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.
+
+
+
+
+ The video features a baby wearing a bright superhero cape, standing confidently with arms raised in a powerful pose. The baby has a determined look on their face, with eyes wide and lips pursed in concentration, as if ready to take on a challenge. The setting appears playful, with colorful toys scattered around and a soft rug underfoot, while sunlight streams through a nearby window, highlighting the fluttering cape and adding to the impression of heroism. The overall atmosphere is lighthearted and fun, with the baby's expressions capturing a mix of innocence and an adorable attempt at bravery, as if truly ready to save the day.
+
+
+
+## Resources
+
+通过以下资源了解有关 ConsisID 的更多信息:
+
+- 一段 [视频](https://www.youtube.com/watch?v=PhlgC-bI5SQ) 演示了 ConsisID 的主要功能;
+- 有关更多详细信息,请参阅研究论文 [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440)。
diff --git a/examples/README.md b/examples/README.md
index c27507040545..7cdf25999ac3 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -40,9 +40,9 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
| [**Dreambooth**](./dreambooth) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
-| [**ControlNet**](./controlnet) | ✅ | ✅ | -
-| [**InstructPix2Pix**](./instruct_pix2pix) | ✅ | ✅ | -
-| [**Reinforcement Learning for Control**](./reinforcement_learning) | - | - | coming soon.
+| [**ControlNet**](./controlnet) | ✅ | ✅ | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
+| [**InstructPix2Pix**](./instruct_pix2pix) | ✅ | ✅ | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/InstructPix2Pix_using_diffusers.ipynb)
+| [**Reinforcement Learning for Control**](./reinforcement_learning) | - | - | [Notebook1](https://github.com/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_for_control.ipynb), [Notebook2](https://github.com/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb)
## Community
diff --git a/examples/advanced_diffusion_training/README.md b/examples/advanced_diffusion_training/README.md
index cd8c5feda9f0..504ae1471f44 100644
--- a/examples/advanced_diffusion_training/README.md
+++ b/examples/advanced_diffusion_training/README.md
@@ -67,6 +67,17 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
+Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
+```bash
+huggingface-cli login
+```
+This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.
+
+> [!NOTE]
+> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
+> `pip install wandb`
+> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
+
### Pivotal Tuning
**Training with text encoder(s)**
diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md
index 8817431bede5..f2a571d5eae4 100644
--- a/examples/advanced_diffusion_training/README_flux.md
+++ b/examples/advanced_diffusion_training/README_flux.md
@@ -65,16 +65,27 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
+Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
+```bash
+huggingface-cli login
+```
+This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.
+
+> [!NOTE]
+> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
+> `pip install wandb`
+> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
+
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
-applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
+applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
the exact modules for LoRA training. Here are some examples of target modules you can provide:
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
> [!NOTE]
-> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
+> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
> [!NOTE]
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 3db6896228de..b8194507d822 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -74,7 +74,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -227,7 +227,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = nullcontext()
with autocast_ctx:
@@ -378,7 +378,7 @@ def parse_args(input_args=None):
default=None,
help="the concept to use to initialize the new inserted tokens when training with "
"--train_text_encoder_ti = True. By default, new tokens (
) are initialized with random value. "
- "Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. "
+ "Alternatively, you could specify a different word/words whose value will be used as the starting point for the new inserted tokens. "
"--num_new_tokens_per_abstraction is ignored when initializer_concept is provided",
)
parser.add_argument(
@@ -662,7 +662,7 @@ def parse_args(input_args=None):
type=str,
default=None,
help=(
- "The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. "
+ "The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. "
'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
),
)
@@ -880,9 +880,7 @@ def save_embeddings(self, file_path: str):
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
for idx, text_encoder in enumerate(self.text_encoders):
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
- embeds = (
- text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
- )
+ embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
new_token_embeddings = embeds.weight.data[train_ids]
@@ -904,9 +902,7 @@ def device(self):
@torch.no_grad()
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
- embeds = (
- text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
- )
+ embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
embeds.weight.data[index_no_updates] = (
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
@@ -1650,6 +1646,8 @@ def save_model_hook(models, weights, output_dir):
elif isinstance(model, type(unwrap_model(text_encoder_one))):
if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ pass # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1747,7 +1745,7 @@ def load_model_hook(models, input_dir):
if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well
text_lora_parameters_two = []
for name, param in text_encoder_two.named_parameters():
- if "token_embedding" in name:
+ if "shared" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
@@ -1776,15 +1774,10 @@ def load_model_hook(models, input_dir):
if not args.enable_t5_ti:
# pure textual inversion - only clip
if pure_textual_inversion:
- params_to_optimize = [
- text_parameters_one_with_lr,
- ]
+ params_to_optimize = [text_parameters_one_with_lr]
te_idx = 0
else: # regular te training or regular pivotal for clip
- params_to_optimize = [
- transformer_parameters_with_lr,
- text_parameters_one_with_lr,
- ]
+ params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
te_idx = 1
elif args.enable_t5_ti:
# pivotal tuning of clip & t5
@@ -1807,9 +1800,7 @@ def load_model_hook(models, input_dir):
]
te_idx = 1
else:
- params_to_optimize = [
- transformer_parameters_with_lr,
- ]
+ params_to_optimize = [transformer_parameters_with_lr]
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
@@ -1869,7 +1860,6 @@ def load_model_hook(models, input_dir):
params_to_optimize[-1]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
- lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -2160,6 +2150,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
+ elems_to_repeat = 1
if freeze_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
prompts, text_encoders, tokenizers
@@ -2174,17 +2165,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
max_sequence_length=args.max_sequence_length,
add_special_tokens=add_special_tokens_t5,
)
+ else:
+ elems_to_repeat = len(prompts)
if not freeze_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=[None, None],
- text_input_ids_list=[tokens_one, tokens_two],
+ text_input_ids_list=[
+ tokens_one.repeat(elems_to_repeat, 1),
+ tokens_two.repeat(elems_to_repeat, 1),
+ ],
max_sequence_length=args.max_sequence_length,
device=accelerator.device,
prompt=prompts,
)
-
# Convert images to latent space
if args.cache_latents:
model_input = latents_cache[step].sample()
@@ -2198,8 +2193,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
- model_input.shape[2],
- model_input.shape[3],
+ model_input.shape[2] // 2,
+ model_input.shape[3] // 2,
accelerator.device,
weight_dtype,
)
@@ -2253,8 +2248,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)[0]
model_pred = FluxPipeline._unpack_latents(
model_pred,
- height=int(model_input.shape[2] * vae_scale_factor / 2),
- width=int(model_input.shape[3] * vae_scale_factor / 2),
+ height=model_input.shape[2] * vae_scale_factor,
+ width=model_input.shape[3] * vae_scale_factor,
vae_scale_factor=vae_scale_factor,
)
@@ -2377,6 +2372,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
epoch=epoch,
torch_dtype=weight_dtype,
)
+ images = None
+ del pipeline
+
if freeze_text_encoder:
del text_encoder_one, text_encoder_two
free_memory()
@@ -2454,6 +2452,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
+ images = None
+ del pipeline
accelerator.end_training()
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 7e1a0298ba1d..8cd1d777c00c 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,7 +39,7 @@
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -59,19 +59,21 @@
)
from diffusers.loaders import StableDiffusionLoraLoaderMixin
from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
from diffusers.utils import (
check_min_version,
convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
+ convert_unet_state_dict_to_peft,
is_wandb_available,
)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -79,30 +81,27 @@
def save_model_card(
repo_id: str,
use_dora: bool,
- images=None,
- base_model=str,
+ images: list = None,
+ base_model: str = None,
train_text_encoder=False,
train_text_encoder_ti=False,
token_abstraction_dict=None,
- instance_prompt=str,
- validation_prompt=str,
+ instance_prompt=None,
+ validation_prompt=None,
repo_folder=None,
vae_path=None,
):
- img_str = "widget:\n"
lora = "lora" if not use_dora else "dora"
- for i, image in enumerate(images):
- image.save(os.path.join(repo_folder, f"image_{i}.png"))
- img_str += f"""
- - text: '{validation_prompt if validation_prompt else ' ' }'
- output:
- url:
- "image_{i}.png"
- """
- if not images:
- img_str += f"""
- - text: '{instance_prompt}'
- """
+
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+ else:
+ widget_dict.append({"text": instance_prompt})
embeddings_filename = f"{repo_folder}_emb"
instance_prompt_webui = re.sub(r"", "", re.sub(r"", embeddings_filename, instance_prompt, count=1))
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt))
@@ -137,24 +136,7 @@ def save_model_card(
trigger_str += f"""
to trigger concept `{key}` → use `{tokens}` in your prompt \n
"""
-
- yaml = f"""---
-tags:
-- stable-diffusion
-- stable-diffusion-diffusers
-- diffusers-training
-- text-to-image
-- diffusers
-- {lora}
-- template:sd-lora
-{img_str}
-base_model: {base_model}
-instance_prompt: {instance_prompt}
-license: openrail++
----
-"""
-
- model_card = f"""
+ model_description = f"""
# SD1.5 LoRA DreamBooth - {repo_id}
@@ -178,7 +160,7 @@ def save_model_card(
from diffusers import AutoPipelineForText2Image
import torch
{diffusers_imports_pivotal}
-pipeline = AutoPipelineForText2Image.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to('cuda')
+pipeline = AutoPipelineForText2Image.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5', torch_dtype=torch.float16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
{diffusers_example_pivotal}
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
@@ -202,8 +184,28 @@ def save_model_card(
Special VAE used for training: {vae_path}.
"""
- with open(os.path.join(repo_folder, "README.md"), "w") as f:
- f.write(yaml + model_card)
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ inference=True,
+ widget=widget_dict,
+ )
+
+ tags = [
+ "text-to-image",
+ "diffusers",
+ "diffusers-training",
+ lora,
+ "template:sd-lora" "stable-diffusion",
+ "stable-diffusion-diffusers",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
def import_model_class_from_model_name_or_path(
@@ -660,7 +662,7 @@ def parse_args(input_args=None):
action="store_true",
default=False,
help=(
- "Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
+ "Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)
@@ -1318,6 +1320,37 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
+ lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ if args.train_text_encoder:
+ # Do we need to call `scale_lora_layers()` here?
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ _set_state_dict_into_text_encoder(
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [unet_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
@@ -1358,10 +1391,7 @@ def load_model_hook(models, input_dir):
else args.adam_weight_decay,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
- params_to_optimize = [
- unet_lora_parameters_with_lr,
- text_lora_parameters_one_with_lr,
- ]
+ params_to_optimize = [unet_lora_parameters_with_lr, text_lora_parameters_one_with_lr]
else:
params_to_optimize = [unet_lora_parameters_with_lr]
@@ -1423,7 +1453,6 @@ def load_model_hook(models, input_dir):
optimizer = optimizer_class(
params_to_optimize,
- lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -1854,7 +1883,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
pipeline.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = (
+ torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ if args.seed is not None
+ else None
+ )
pipeline_args = {"prompt": args.validation_prompt}
if torch.backends.mps.is_available():
@@ -1958,7 +1991,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
)
# run inference
pipeline = pipeline.to(accelerator.device)
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = (
+ torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+ )
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 5222c8afe6f1..f8253715e64d 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -71,6 +71,7 @@
convert_unet_state_dict_to_peft,
is_wandb_available,
)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -79,7 +80,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -101,7 +102,7 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision):
def save_model_card(
repo_id: str,
use_dora: bool,
- images=None,
+ images: list = None,
base_model: str = None,
train_text_encoder=False,
train_text_encoder_ti=False,
@@ -111,20 +112,17 @@ def save_model_card(
repo_folder=None,
vae_path=None,
):
- img_str = "widget:\n"
lora = "lora" if not use_dora else "dora"
- for i, image in enumerate(images):
- image.save(os.path.join(repo_folder, f"image_{i}.png"))
- img_str += f"""
- - text: '{validation_prompt if validation_prompt else ' ' }'
- output:
- url:
- "image_{i}.png"
- """
- if not images:
- img_str += f"""
- - text: '{instance_prompt}'
- """
+
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+ else:
+ widget_dict.append({"text": instance_prompt})
embeddings_filename = f"{repo_folder}_emb"
instance_prompt_webui = re.sub(r"", "", re.sub(r"", embeddings_filename, instance_prompt, count=1))
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt))
@@ -169,23 +167,7 @@ def save_model_card(
to trigger concept `{key}` → use `{tokens}` in your prompt \n
"""
- yaml = f"""---
-tags:
-- stable-diffusion-xl
-- stable-diffusion-xl-diffusers
-- diffusers-training
-- text-to-image
-- diffusers
-- {lora}
-- template:sd-lora
-{img_str}
-base_model: {base_model}
-instance_prompt: {instance_prompt}
-license: openrail++
----
-"""
-
- model_card = f"""
+ model_description = f"""
# SDXL LoRA DreamBooth - {repo_id}
@@ -234,8 +216,25 @@ def save_model_card(
{license}
"""
- with open(os.path.join(repo_folder, "README.md"), "w") as f:
- f.write(yaml + model_card)
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "stable-diffusion-xl",
+ "stable-diffusion-xl-diffusers",
+ "text-to-image",
+ "diffusers",
+ lora,
+ "template:sd-lora",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
def log_validation(
@@ -269,7 +268,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
@@ -773,7 +772,7 @@ def parse_args(input_args=None):
action="store_true",
default=False,
help=(
- "Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
+ "Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)
@@ -1794,7 +1793,6 @@ def load_model_hook(models, input_dir):
optimizer = optimizer_class(
params_to_optimize,
- lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -1876,7 +1874,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
- # if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion
+ # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion
add_special_tokens = True if args.train_text_encoder_ti else False
if not train_dataset.custom_instance_prompts:
diff --git a/examples/amused/train_amused.py b/examples/amused/train_amused.py
index ede51775dd8f..df44a0a63aeb 100644
--- a/examples/amused/train_amused.py
+++ b/examples/amused/train_amused.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
index 0fdca2850784..eed8305f4fbc 100644
--- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py
+++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
@@ -61,7 +61,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -722,7 +722,7 @@ def log_validation(
# pipe.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
videos = []
for _ in range(args.num_validation_videos):
@@ -872,10 +872,9 @@ def prepare_rotary_positional_embeddings(
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
+ device=device,
)
- freqs_cos = freqs_cos.to(device=device)
- freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@@ -947,7 +946,6 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
optimizer = optimizer_class(
params_to_optimize,
- lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index ece2228147e2..74ea98cbac5e 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -52,7 +52,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -739,7 +739,7 @@ def log_validation(
# pipe.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
videos = []
for _ in range(args.num_validation_videos):
@@ -894,10 +894,9 @@ def prepare_rotary_positional_embeddings(
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
+ device=device,
)
- freqs_cos = freqs_cos.to(device=device)
- freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@@ -969,7 +968,6 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
optimizer = optimizer_class(
params_to_optimize,
- lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
diff --git a/examples/cogview4-control/README.md b/examples/cogview4-control/README.md
new file mode 100644
index 000000000000..746a99a1a41b
--- /dev/null
+++ b/examples/cogview4-control/README.md
@@ -0,0 +1,201 @@
+# Training CogView4 Control
+
+This (experimental) example shows how to train Control LoRAs with [CogView4](https://huggingface.co/THUDM/CogView4-6B) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about CogView4 Control family, refer to the following resources:
+
+To incorporate additional condition latents, we expand the input features of CogView-4 from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `patch_embed` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `CogView4ControlPipeline`.
+
+> [!NOTE]
+> **Gated model**
+>
+> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+
+```bash
+huggingface-cli login
+```
+
+The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.
+
+```bash
+accelerate launch train_control_lora_cogview4.py \
+ --pretrained_model_name_or_path="THUDM/CogView4-6B" \
+ --dataset_name="raulc0399/open_pose_controlnet" \
+ --output_dir="pose-control-lora" \
+ --mixed_precision="bf16" \
+ --train_batch_size=1 \
+ --rank=64 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=5000 \
+ --validation_image="openpose.png" \
+ --validation_prompt="A couple, 4k photo, highly detailed" \
+ --offload \
+ --seed="0" \
+ --push_to_hub
+```
+
+`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).
+
+You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). When it's merged, you should install `diffusers` from the `main`.
+
+The training script exposes additional CLI args that might be useful to experiment with:
+
+* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer.
+* `train_norm_layers`: When set, additionally trains the normalization scales. Takes care of saving and loading.
+* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify "all-linear", all the linear layers will be LoRA-attached.
+
+### Training with DeepSpeed
+
+It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the Zero2 system optimization. To use it, save the following config to an YAML file (feel free to modify as needed):
+
+```yaml
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ gradient_accumulation_steps: 1
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+And then while launching training, pass the config file:
+
+```bash
+accelerate launch --config_file=CONFIG_FILE.yaml ...
+```
+
+### Inference
+
+The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first:
+
+```bash
+pip install controlnet_aux
+```
+
+And then we are ready:
+
+```py
+from controlnet_aux import OpenposeDetector
+from diffusers import CogView4ControlPipeline
+from diffusers.utils import load_image
+from PIL import Image
+import numpy as np
+import torch
+
+pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
+pipe.load_lora_weights("...") # change this.
+
+open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
+
+# prepare pose condition.
+url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
+image = load_image(url)
+image = open_pose(image, detect_resolution=512, image_resolution=1024)
+image = np.array(image)[:, :, ::-1]
+image = Image.fromarray(np.uint8(image))
+
+prompt = "A couple, 4k photo, highly detailed"
+
+gen_images = pipe(
+ prompt=prompt,
+ control_image=image,
+ num_inference_steps=50,
+ joint_attention_kwargs={"scale": 0.9},
+ guidance_scale=25.,
+).images[0]
+gen_images.save("output.png")
+```
+
+## Full fine-tuning
+
+We provide a non-LoRA version of the training script `train_control_cogview4.py`. Here is an example command:
+
+```bash
+accelerate launch --config_file=accelerate_ds2.yaml train_control_cogview4.py \
+ --pretrained_model_name_or_path="THUDM/CogView4-6B" \
+ --dataset_name="raulc0399/open_pose_controlnet" \
+ --output_dir="pose-control" \
+ --mixed_precision="bf16" \
+ --train_batch_size=2 \
+ --dataloader_num_workers=4 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --proportion_empty_prompts=0.2 \
+ --learning_rate=5e-5 \
+ --adam_weight_decay=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="cosine" \
+ --lr_warmup_steps=1000 \
+ --checkpointing_steps=1000 \
+ --max_train_steps=10000 \
+ --validation_steps=200 \
+ --validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
+ --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
+ --offload \
+ --seed="0" \
+ --push_to_hub
+```
+
+Change the `validation_image` and `validation_prompt` as needed.
+
+For inference, this time, we will run:
+
+```py
+from controlnet_aux import OpenposeDetector
+from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel
+from diffusers.utils import load_image
+from PIL import Image
+import numpy as np
+import torch
+
+transformer = CogView4Transformer2DModel.from_pretrained("...") # change this.
+pipe = CogView4ControlPipeline.from_pretrained(
+ "THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16
+).to("cuda")
+
+open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
+
+# prepare pose condition.
+url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
+image = load_image(url)
+image = open_pose(image, detect_resolution=512, image_resolution=1024)
+image = np.array(image)[:, :, ::-1]
+image = Image.fromarray(np.uint8(image))
+
+prompt = "A couple, 4k photo, highly detailed"
+
+gen_images = pipe(
+ prompt=prompt,
+ control_image=image,
+ num_inference_steps=50,
+ guidance_scale=25.,
+).images[0]
+gen_images.save("output.png")
+```
+
+## Things to note
+
+* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
+* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
+* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
\ No newline at end of file
diff --git a/examples/cogview4-control/requirements.txt b/examples/cogview4-control/requirements.txt
new file mode 100644
index 000000000000..6c5ec2e03f9a
--- /dev/null
+++ b/examples/cogview4-control/requirements.txt
@@ -0,0 +1,6 @@
+transformers==4.47.0
+wandb
+torch
+torchvision
+accelerate==1.2.0
+peft>=0.14.0
diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py
new file mode 100644
index 000000000000..506ca0225bf7
--- /dev/null
+++ b/examples/cogview4-control/train_control_cogview4.py
@@ -0,0 +1,1242 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import copy
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ CogView4ControlPipeline,
+ CogView4Transformer2DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ free_memory,
+)
+from diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.33.0.dev0")
+
+logger = get_logger(__name__)
+
+NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
+
+
+def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype):
+ pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample()
+ pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor
+ return pixel_latents.to(weight_dtype)
+
+
+def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step, is_final_validation=False):
+ logger.info("Running validation... ")
+
+ if not is_final_validation:
+ cogview4_transformer = accelerator.unwrap_model(cogview4_transformer)
+ pipeline = CogView4ControlPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=cogview4_transformer,
+ torch_dtype=weight_dtype,
+ )
+ else:
+ transformer = CogView4Transformer2DModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
+ pipeline = CogView4ControlPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=transformer,
+ torch_dtype=weight_dtype,
+ )
+
+ pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+ if is_final_validation or torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype)
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = load_image(validation_image)
+ # maybe need to inference on 1024 to get a good image
+ validation_image = validation_image.resize((args.resolution, args.resolution))
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with autocast_ctx:
+ image = pipeline(
+ prompt=validation_prompt,
+ control_image=validation_image,
+ num_inference_steps=50,
+ guidance_scale=args.guidance_scale,
+ max_sequence_length=args.max_sequence_length,
+ generator=generator,
+ height=args.resolution,
+ width=args.resolution,
+ ).images[0]
+ image = image.resize((args.resolution, args.resolution))
+ images.append(image)
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ formatted_images = []
+ formatted_images.append(np.asarray(validation_image))
+ for image in images:
+ formatted_images.append(np.asarray(image))
+ formatted_images = np.stack(formatted_images)
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+
+ elif tracker.name == "wandb":
+ formatted_images = []
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({tracker_key: formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ free_memory()
+ return image_logs
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# cogview4-control-{repo_id}
+
+These are Control weights trained on {base_model} with new type of conditioning.
+{img_str}
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogView4-6b/blob/main/LICENSE.md)
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "cogview4",
+ "cogview4-diffusers",
+ "text-to-image",
+ "diffusers",
+ "control",
+ "diffusers-training",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a CogView4 Control training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="cogview4-control",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--max_sequence_length", type=int, default=128, help="The maximum sequence length for the prompt."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the control conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=1,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="cogview4_train_control",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+ parser.add_argument(
+ "--jsonl_for_train",
+ type=str,
+ default=None,
+ help="Path to the jsonl file containing the training data.",
+ )
+ parser.add_argument(
+ "--only_target_transformer_blocks",
+ action="store_true",
+ help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).",
+ )
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+ help="the guidance scale used for transformer.",
+ )
+
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--offload",
+ action="store_true",
+ help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.jsonl_for_train is None:
+ raise ValueError("Specify either `--dataset_name` or `--jsonl_for_train`")
+
+ if args.dataset_name is not None and args.jsonl_for_train is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--jsonl_for_train`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the cogview4 transformer."
+ )
+
+ return args
+
+
+def get_train_dataset(args, accelerator):
+ dataset = None
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ if args.jsonl_for_train is not None:
+ # load from json
+ dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)
+ dataset = dataset.flatten_indices()
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ with accelerator.main_process_first():
+ train_dataset = dataset["train"].shuffle(seed=args.seed)
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(args.max_train_samples))
+ return train_dataset
+
+
+def prepare_train_dataset(dataset, accelerator):
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.ToTensor(),
+ transforms.Lambda(lambda x: x * 2 - 1),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [
+ (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
+ for image in examples[args.image_column]
+ ]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [
+ (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
+ for image in examples[args.conditioning_image_column]
+ ]
+ conditioning_images = [image_transforms(image) for image in conditioning_images]
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+
+ is_caption_list = isinstance(examples[args.caption_column][0], list)
+ if is_caption_list:
+ examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+ else:
+ examples["captions"] = list(examples[args.caption_column])
+
+ return examples
+
+ with accelerator.main_process_first():
+ dataset = dataset.with_transform(preprocess_train)
+
+ return dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+ captions = [example["captions"] for example in examples]
+ return {"pixel_values": pixel_values, "conditioning_pixel_values": conditioning_pixel_values, "captions": captions}
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_out_dir = Path(args.output_dir, args.logging_dir)
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices.
+ if torch.backends.mps.is_available():
+ logger.info("MPS is enabled. Disabling AMP.")
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ # DEBUG, INFO, WARNING, ERROR, CRITICAL
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load models. We will load the text encoders later in a pipeline to compute
+ # embeddings.
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ cogview4_transformer = CogView4Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ logger.info("All models loaded successfully")
+
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="scheduler",
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ if not args.only_target_transformer_blocks:
+ cogview4_transformer.requires_grad_(True)
+ vae.requires_grad_(False)
+
+ # cast down and move to the CPU
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # let's not move the VAE to the GPU yet.
+ vae.to(dtype=torch.float32) # keep the VAE in float32.
+
+ # enable image inputs
+ with torch.no_grad():
+ patch_size = cogview4_transformer.config.patch_size
+ initial_input_channels = cogview4_transformer.config.in_channels * patch_size**2
+ new_linear = torch.nn.Linear(
+ cogview4_transformer.patch_embed.proj.in_features * 2,
+ cogview4_transformer.patch_embed.proj.out_features,
+ bias=cogview4_transformer.patch_embed.proj.bias is not None,
+ dtype=cogview4_transformer.dtype,
+ device=cogview4_transformer.device,
+ )
+ new_linear.weight.zero_()
+ new_linear.weight[:, :initial_input_channels].copy_(cogview4_transformer.patch_embed.proj.weight)
+ if cogview4_transformer.patch_embed.proj.bias is not None:
+ new_linear.bias.copy_(cogview4_transformer.patch_embed.proj.bias)
+ cogview4_transformer.patch_embed.proj = new_linear
+
+ assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0)
+ cogview4_transformer.register_to_config(
+ in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels
+ )
+
+ if args.only_target_transformer_blocks:
+ cogview4_transformer.patch_embed.proj.requires_grad_(True)
+ for name, module in cogview4_transformer.named_modules():
+ if "transformer_blocks" in name:
+ module.requires_grad_(True)
+ else:
+ module.requirs_grad_(False)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ if isinstance(unwrap_model(model), type(unwrap_model(cogview4_transformer))):
+ model = unwrap_model(model)
+ model.save_pretrained(os.path.join(output_dir, "transformer"))
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(unwrap_model(model), type(unwrap_model(cogview4_transformer))):
+ transformer_ = model # noqa: F841
+ else:
+ raise ValueError(f"unexpected save model: {unwrap_model(model).__class__}")
+
+ else:
+ transformer_ = CogView4Transformer2DModel.from_pretrained(input_dir, subfolder="transformer") # noqa: F841
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ cogview4_transformer.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimization parameters
+ optimizer = optimizer_class(
+ cogview4_transformer.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Prepare dataset and dataloader.
+ train_dataset = get_train_dataset(args, accelerator)
+ train_dataset = prepare_train_dataset(train_dataset, accelerator)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+ # Prepare everything with our `accelerator`.
+ cogview4_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ cogview4_transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Create a pipeline for text encoding. We will move this pipeline to GPU/CPU as needed.
+ text_encoding_pipeline = CogView4ControlPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype
+ )
+ tokenizer = text_encoding_pipeline.tokenizer
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ logger.info(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+ logger.info("Logging some dataset samples.")
+ formatted_images = []
+ formatted_control_images = []
+ all_prompts = []
+ for i, batch in enumerate(train_dataloader):
+ images = (batch["pixel_values"] + 1) / 2
+ control_images = (batch["conditioning_pixel_values"] + 1) / 2
+ prompts = batch["captions"]
+
+ if len(formatted_images) > 10:
+ break
+
+ for img, control_img, prompt in zip(images, control_images, prompts):
+ formatted_images.append(img)
+ formatted_control_images.append(control_img)
+ all_prompts.append(prompt)
+
+ logged_artifacts = []
+ for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+ logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+ logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+ wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+ wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ cogview4_transformer.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(cogview4_transformer):
+ # Convert images to latent space
+ # vae encode
+ prompts = batch["captions"]
+ attention_mask = tokenizer(
+ prompts,
+ padding="longest", # not use max length
+ max_length=args.max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ ).attention_mask.float()
+
+ pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype)
+ control_latents = encode_images(
+ batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
+ )
+ if args.offload:
+ vae.cpu()
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ bsz = pixel_latents.shape[0]
+ noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype)
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+
+ # Add noise according for cogview4
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
+ sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device)
+ captions = batch["captions"]
+ image_seq_lens = torch.tensor(
+ pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size**2,
+ dtype=pixel_latents.dtype,
+ device=pixel_latents.device,
+ ) # H * W / VAE patch_size
+ mu = torch.sqrt(image_seq_lens / 256)
+ mu = mu * 0.75 + 0.25
+ scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(
+ dtype=pixel_latents.dtype, device=pixel_latents.device
+ )
+ scale_factors = scale_factors.view(len(batch["captions"]), 1, 1, 1)
+ noisy_model_input = (1.0 - scale_factors) * pixel_latents + scale_factors * noise
+ concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)
+ text_encoding_pipeline = text_encoding_pipeline.to("cuda")
+
+ with torch.no_grad():
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ ) = text_encoding_pipeline.encode_prompt(captions, "")
+ original_size = (args.resolution, args.resolution)
+ original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
+
+ target_size = (args.resolution, args.resolution)
+ target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
+
+ target_size = target_size.repeat(len(batch["captions"]), 1)
+ original_size = original_size.repeat(len(batch["captions"]), 1)
+ crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device)
+ crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1)
+
+ # this could be optimized by not having to do any text encoding and just
+ # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`
+ if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
+ # Here, we directly pass 16 pad tokens from pooled_prompt_embeds to prompt_embeds.
+ prompt_embeds = pooled_prompt_embeds
+ if args.offload:
+ text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ # Predict.
+ noise_pred_cond = cogview4_transformer(
+ hidden_states=concatenated_noisy_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timesteps,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ return_dict=False,
+ attention_mask=attention_mask,
+ )[0]
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+ # flow-matching loss
+ target = noise - pixel_latents
+
+ weighting = weighting.view(len(batch["captions"]), 1, 1, 1)
+ loss = torch.mean(
+ (weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ params_to_clip = cogview4_transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ cogview4_transformer=cogview4_transformer,
+ args=args,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ step=global_step,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ cogview4_transformer = unwrap_model(cogview4_transformer)
+ if args.upcast_before_saving:
+ cogview4_transformer.to(torch.float32)
+ cogview4_transformer.save_pretrained(args.output_dir)
+
+ del cogview4_transformer
+ del text_encoding_pipeline
+ del vae
+ free_memory()
+
+ # Run a final round of validation.
+ image_logs = None
+ if args.validation_prompt is not None:
+ image_logs = log_validation(
+ cogview4_transformer=None,
+ args=args,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ step=global_step,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*", "checkpoint-*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/community/README.md b/examples/community/README.md
old mode 100755
new mode 100644
index 4f16f65df8fa..9d2452e9177a
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -10,72 +10,82 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Example | Description | Code Example | Colab | Author |
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
-|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
+|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/)|
+|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
+|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/flux_with_cfg.ipynb)|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[](https://huggingface.co/spaces/exx8/differential-diffusion) [](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
| HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
| Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [](https://huggingface.co/spaces/toshas/marigold) [](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) |
| LLM-grounded Diffusion (LMD+) | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion) | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) | [Long (Tony) Lian](https://tonylian.com/) |
| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
| One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see ) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
-| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) |
-| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
-| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) |
-| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech)
-| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) |
-| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
-| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
-| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
-| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) |
-| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | - | [Phạm Hồng Vinh](https://github.com/rootonchair) |
-| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
-| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
+| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_interpolation.ipynb) | [Nate Raw](https://github.com/nateraw/) |
+| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_mega.ipynb) | [Patrick von Platen](https://github.com/patrickvonplaten/) |
+| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/long_prompt_weighting_stable_diffusion.ipynb) | [SkyTNT](https://github.com/SkyTNT) |
+| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/speech_to_image.ipynb) | [Mikail Duzenli](https://github.com/MikailINTech)
+| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/wildcard_stable_diffusion.ipynb) | [Shyam Sudhakaran](https://github.com/shyamsn97) |
+| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/composable_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
+| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/seed_resizing.ipynb) | [Mark Rich](https://github.com/MarkRich) |
+| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/imagic_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
+| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/multilingual_stable_diffusion.ipynb) | [Juan Carlos Piñeros](https://github.com/juancopi81) |
+| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/gluegen_stable_diffusion.ipynb) | [Phạm Hồng Vinh](https://github.com/rootonchair) |
+| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/image_to_image_inpainting_stable_diffusion.ipynb) | [Alex McKinney](https://github.com/vvvm23) |
+| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/text_based_inpainting_stable_dffusion.ipynb) | [Dhruv Karan](https://github.com/unography) |
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
-| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
-| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
-| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - | [Ray Wang](https://wrong.wang) |
-| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
-| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
-| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - | [Aengus (Duc-Anh)](https://github.com/aengusng8) |
-| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
-| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
-| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | - | [Joqsan Azocar](https://github.com/Joqsan) |
-| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.09865) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint ) | - | [Markus Pobitzer](https://github.com/Markus-Pobitzer) |
+| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_comparison.ipynb) | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
+| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/magic_mix.ipynb) | [Partho Das](https://github.com/daspartho) |
+| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_unclip.ipynb) | [Ray Wang](https://wrong.wang) |
+| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_text_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
+| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_image_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
+| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
+| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/clip_guided_img2img_stable_diffusion.ipynb) | [Nipun Jindal](https://github.com/nipunjindal/) |
+| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/tensorrt_text2image_stable_diffusion_pipeline.ipynb) | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
+| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) |
+| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.09865) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint )|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_repaint.ipynb)| [Markus Pobitzer](https://github.com/Markus-Pobitzer) |
| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) |
-| CLIP Guided Images Mixing Stable Diffusion Pipeline | Сombine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | - | [Karachev Denis](https://github.com/TheDenk) |
+| CLIP Guided Images Mixing Stable Diffusion Pipeline | Сombine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/clip_guided_images_mixing_with_stable_diffusion.ipynb) | [Karachev Denis](https://github.com/TheDenk) |
| TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |
-| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
+| Stable Diffusion Mixture Tiling Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SD 1.5](#stable-diffusion-mixture-tiling-pipeline-sd-15) | [](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |
+| Stable Diffusion Mixture Canvas Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending. Works by defining a list of Text2Image region objects that detail the region of influence of each diffuser. | [Stable Diffusion Mixture Canvas Pipeline SD 1.5](#stable-diffusion-mixture-canvas-pipeline-sd-15) | [](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |
+| Stable Diffusion Mixture Tiling Pipeline SDXL | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SDXL](#stable-diffusion-mixture-tiling-pipeline-sdxl) | [](https://huggingface.co/spaces/elismasilva/mixture-of-diffusers-sdxl-tiling) | [Eliseu Silva](https://github.com/DEVAIEXP/) |
+| Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL | This is an advanced pipeline that leverages ControlNet Tile and Mixture-of-Diffusers techniques, integrating tile diffusion directly into the latent space denoising process. Designed to overcome the limitations of conventional pixel-space tile processing, this pipeline delivers Super Resolution (SR) upscaling for higher-quality images, reduced processing time, and greater adaptability. | [Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL](#stable-diffusion-mod-controlnet-tile-sr-pipeline-sdxl) | [](https://huggingface.co/spaces/elismasilva/mod-control-tile-upscaler-sdxl) | [Eliseu Silva](https://github.com/DEVAIEXP/) |
+| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
-| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
+| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/prompt_2_prompt_pipeline.ipynb) | [Umer H. Adil](https://twitter.com/UmerHAdil) |
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
-| SDE Drag Pipeline | The pipeline supports drag editing of images using stochastic differential equations | [SDE Drag Pipeline](#sde-drag-pipeline) | - | [NieShen](https://github.com/NieShenRuc) [Fengqi Zhu](https://github.com/Monohydroxides) |
+| SDE Drag Pipeline | The pipeline supports drag editing of images using stochastic differential equations | [SDE Drag Pipeline](#sde-drag-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/sde_drag.ipynb) | [NieShen](https://github.com/NieShenRuc) [Fengqi Zhu](https://github.com/Monohydroxides) |
| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
-| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#demofusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) |
-| Instaflow Pipeline | Implementation of [InstaFlow! One-Step Stable Diffusion with Rectified Flow](https://arxiv.org/abs/2309.06380) | [Instaflow Pipeline](#instaflow-pipeline) | - | [Ayush Mangal](https://github.com/ayushtues) |
+| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#demofusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/demo_fusion.ipynb) | [Ruoyi Du](https://github.com/RuoyiDu) |
+| Instaflow Pipeline | Implementation of [InstaFlow! One-Step Stable Diffusion with Rectified Flow](https://arxiv.org/abs/2309.06380) | [Instaflow Pipeline](#instaflow-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/insta_flow.ipynb) | [Ayush Mangal](https://github.com/ayushtues) |
| Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) |
| Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#rerender-a-video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
-| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
+| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ip_adapter_face_id.ipynb)| [Fabio Rigano](https://github.com/fabiorigano) |
| InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) |
| UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) |
| Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
+PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
-
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
+| Perturbed-Attention Guidance |StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).|[Perturbed-Attention Guidance](#perturbed-attention-guidance)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/perturbed_attention_guidance.ipynb)|[Hyoungwon Cho](https://github.com/HyoungwonCho)|
+| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
+| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://arxiv.org/abs/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -84,6 +94,210 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
## Example usages
+### Spatiotemporal Skip Guidance
+
+**Junha Hyung\*, Kinam Kim\*, Susung Hong, Min-Jung Kim, Jaegul Choo**
+
+**KAIST AI, University of Washington**
+
+[*Spatiotemporal Skip Guidance (STG) for Enhanced Video Diffusion Sampling*](https://arxiv.org/abs/2411.18664) (CVPR 2025) is a simple training-free sampling guidance method for enhancing transformer-based video diffusion models. STG employs an implicit weak model via self-perturbation, avoiding the need for external models or additional training. By selectively skipping spatiotemporal layers, STG produces an aligned, degraded version of the original model to boost sample quality without compromising diversity or dynamic degree.
+
+Following is the example video of STG applied to Mochi.
+
+
+https://github.com/user-attachments/assets/148adb59-da61-4c50-9dfa-425dcb5c23b3
+
+More examples and information can be found on the [GitHub repository](https://github.com/junhahyung/STGuidance) and the [Project website](https://junhahyung.github.io/STGuidance/).
+
+#### Usage example
+```python
+import torch
+from pipeline_stg_mochi import MochiSTGPipeline
+from diffusers.utils import export_to_video
+
+# Load the pipeline
+pipe = MochiSTGPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
+
+# Enable memory savings
+pipe = pipe.to("cuda")
+
+#--------Option--------#
+prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
+stg_applied_layers_idx = [34]
+stg_mode = "STG"
+stg_scale = 1.0 # 0.0 for CFG
+#----------------------#
+
+# Generate video frames
+frames = pipe(
+ prompt,
+ height=480,
+ width=480,
+ num_frames=81,
+ stg_applied_layers_idx=stg_applied_layers_idx,
+ stg_scale=stg_scale,
+ generator = torch.Generator().manual_seed(42),
+ do_rescaling=do_rescaling,
+).frames[0]
+
+export_to_video(frames, "output.mp4", fps=30)
+```
+
+### Adaptive Mask Inpainting
+
+**Hyeonwoo Kim\*, Sookwan Han\*, Patrick Kwon, Hanbyul Joo**
+
+**Seoul National University, Naver Webtoon**
+
+Adaptive Mask Inpainting, presented in the ECCV'24 oral paper [*Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models*](https://snuvclab.github.io/coma), is an algorithm designed to insert humans into scene images without altering the background. Traditional inpainting methods often fail to preserve object geometry and details within the masked region, leading to false affordances. Adaptive Mask Inpainting addresses this issue by progressively specifying the inpainting region over diffusion timesteps, ensuring that the inserted human integrates seamlessly with the existing scene.
+
+Here is the demonstration of Adaptive Mask Inpainting:
+
+
+
+ Your browser does not support the video tag.
+
+
+
+
+
+You can find additional information about Adaptive Mask Inpainting in the [paper](https://arxiv.org/pdf/2401.12978) or in the [project website](https://snuvclab.github.io/coma).
+
+#### Usage example
+First, clone the diffusers github repository, and run the following command to set environment.
+```Shell
+git clone https://github.com/huggingface/diffusers.git
+cd diffusers
+
+conda create --name ami python=3.9 -y
+conda activate ami
+
+conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge -y
+python -m pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
+pip install easydict
+pip install diffusers==0.20.2 accelerate safetensors transformers
+pip install setuptools==59.5.0
+pip install opencv-python
+pip install numpy==1.24.1
+```
+Then, run the below code under 'diffusers' directory.
+```python
+import numpy as np
+import torch
+from PIL import Image
+
+from diffusers import DDIMScheduler
+from diffusers import DiffusionPipeline
+from diffusers.utils import load_image
+
+from examples.community.adaptive_mask_inpainting import download_file, AdaptiveMaskInpaintPipeline, AMI_INSTALL_MESSAGE
+
+print(AMI_INSTALL_MESSAGE)
+
+from easydict import EasyDict
+
+
+
+if __name__ == "__main__":
+ """
+ Download Necessary Files
+ """
+ download_file(
+ url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/model_final_edd263.pkl?download=true",
+ output_file = "model_final_edd263.pkl",
+ exist_ok=True,
+ )
+ download_file(
+ url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/pointrend_rcnn_R_50_FPN_3x_coco.yaml?download=true",
+ output_file = "pointrend_rcnn_R_50_FPN_3x_coco.yaml",
+ exist_ok=True,
+ )
+ download_file(
+ url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_img.png?download=true",
+ output_file = "input_img.png",
+ exist_ok=True,
+ )
+ download_file(
+ url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_mask.png?download=true",
+ output_file = "input_mask.png",
+ exist_ok=True,
+ )
+ download_file(
+ url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-PointRend-RCNN-FPN.yaml?download=true",
+ output_file = "Base-PointRend-RCNN-FPN.yaml",
+ exist_ok=True,
+ )
+ download_file(
+ url = "https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-RCNN-FPN.yaml?download=true",
+ output_file = "Base-RCNN-FPN.yaml",
+ exist_ok=True,
+ )
+
+ """
+ Prepare Adaptive Mask Inpainting Pipeline
+ """
+ # device
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ num_steps = 50
+
+ # Scheduler
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False
+ )
+ scheduler.set_timesteps(num_inference_steps=num_steps)
+
+ ## load models as pipelines
+ pipeline = AdaptiveMaskInpaintPipeline.from_pretrained(
+ "Uminosachi/realisticVisionV51_v51VAE-inpainting",
+ scheduler=scheduler,
+ torch_dtype=torch.float16,
+ requires_safety_checker=False
+ ).to(device)
+
+ ## disable safety checker
+ enable_safety_checker = False
+ if not enable_safety_checker:
+ pipeline.safety_checker = None
+
+ """
+ Run Adaptive Mask Inpainting
+ """
+ default_mask_image = Image.open("./input_mask.png").convert("L")
+ init_image = Image.open("./input_img.png").convert("RGB")
+
+
+ seed = 59
+ generator = torch.Generator(device=device)
+ generator.manual_seed(seed)
+
+ image = pipeline(
+ prompt="a man sitting on a couch",
+ negative_prompt="worst quality, normal quality, low quality, bad anatomy, artifacts, blurry, cropped, watermark, greyscale, nsfw",
+ image=init_image,
+ default_mask_image=default_mask_image,
+ guidance_scale=11.0,
+ strength=0.98,
+ use_adaptive_mask=True,
+ generator=generator,
+ enforce_full_mask_ratio=0.0,
+ visualization_save_dir="./ECCV2024_adaptive_mask_inpainting_demo", # DON'T CHANGE THIS!!!
+ human_detection_thres=0.015,
+ ).images[0]
+
+
+ image.save(f'final_img.png')
+```
+#### [Troubleshooting]
+
+If you run into an error `cannot import name 'cached_download' from 'huggingface_hub'` (issue [1851](https://github.com/easydiffusion/easydiffusion/issues/1851)), remove `cached_download` from the import line in the file `diffusers/utils/dynamic_modules_utils.py`.
+
+For example, change the import line from `.../env/lib/python3.8/site-packages/diffusers/utils/dynamic_modules_utils.py`.
+
+
### Flux with CFG
Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
@@ -94,24 +308,30 @@ Example usage:
from diffusers import DiffusionPipeline
import torch
+model_name = "black-forest-labs/FLUX.1-dev"
+prompt = "a watercolor painting of a unicorn"
+negative_prompt = "pink"
+
+# Load the diffusion pipeline
pipeline = DiffusionPipeline.from_pretrained(
- "black-forest-labs/FLUX.1-dev",
+ model_name,
torch_dtype=torch.bfloat16,
custom_pipeline="pipeline_flux_with_cfg"
)
pipeline.enable_model_cpu_offload()
-prompt = "a watercolor painting of a unicorn"
-negative_prompt = "pink"
+# Generate the image
img = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
true_cfg=1.5,
guidance_scale=3.5,
- num_images_per_prompt=1,
generator=torch.manual_seed(0)
).images[0]
+
+# Save the generated image
img.save("cfg_flux.png")
+print("Image generated and saved successfully.")
```
### Differential Diffusion
@@ -684,6 +904,8 @@ out = pipe(
wildcard_files=["object.txt", "animal.txt"],
num_prompt_samples=1
)
+out.images[0].save("image.png")
+torch.cuda.empty_cache()
```
### Composable Stable diffusion
@@ -732,6 +954,7 @@ for i in range(args.num_images):
images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.)
grid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0)
tvu.save_image(grid, f'{prompt}_{args.weights}' + '.png')
+print("Image saved successfully!")
```
### Imagic Stable Diffusion
@@ -782,10 +1005,15 @@ image.save('./imagic/imagic_image_alpha_2.png')
Test seed resizing. Originally generate an image in 512 by 512, then generate image with same seed at 512 by 592 using seed resizing. Finally, generate 512 by 592 using original stable diffusion pipeline.
```python
+import os
import torch as th
import numpy as np
from diffusers import DiffusionPipeline
+# Ensure the save directory exists or create it
+save_dir = './seed_resize/'
+os.makedirs(save_dir, exist_ok=True)
+
has_cuda = th.cuda.is_available()
device = th.device('cpu' if not has_cuda else 'cuda')
@@ -799,7 +1027,6 @@ def dummy(images, **kwargs):
pipe.safety_checker = dummy
-
images = []
th.manual_seed(0)
generator = th.Generator("cuda").manual_seed(0)
@@ -818,15 +1045,14 @@ res = pipe(
width=width,
generator=generator)
image = res.images[0]
-image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
-
+image.save(os.path.join(save_dir, 'seed_resize_{w}_{h}_image.png'.format(w=width, h=height)))
th.manual_seed(0)
generator = th.Generator("cuda").manual_seed(0)
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
- custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
+ custom_pipeline="seed_resize_stable_diffusion"
).to(device)
width = 512
@@ -840,11 +1066,11 @@ res = pipe(
width=width,
generator=generator)
image = res.images[0]
-image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
+image.save(os.path.join(save_dir, 'seed_resize_{w}_{h}_image.png'.format(w=width, h=height)))
pipe_compare = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
- custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
+ custom_pipeline="seed_resize_stable_diffusion"
).to(device)
res = pipe_compare(
@@ -857,7 +1083,7 @@ res = pipe_compare(
)
image = res.images[0]
-image.save('./seed_resize/seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height))
+image.save(os.path.join(save_dir, 'seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height)))
```
### Multilingual Stable Diffusion Pipeline
@@ -934,38 +1160,100 @@ GlueGen is a minimal adapter that allows alignment between any encoder (Text Enc
Make sure you downloaded `gluenet_French_clip_overnorm_over3_noln.ckpt` for French (there are also pre-trained weights for Chinese, Italian, Japanese, Spanish or train your own) at [GlueGen's official repo](https://github.com/salesforce/GlueGen/tree/main).
```python
-from PIL import Image
-
+import os
+import gc
+import urllib.request
import torch
+from transformers import XLMRobertaTokenizer, XLMRobertaForMaskedLM, CLIPTokenizer, CLIPTextModel
+from diffusers import DiffusionPipeline
-from transformers import AutoModel, AutoTokenizer
+# Download checkpoints
+CHECKPOINTS = [
+ "https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Chinese_clip_overnorm_over3_noln.ckpt",
+ "https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_French_clip_overnorm_over3_noln.ckpt",
+ "https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Italian_clip_overnorm_over3_noln.ckpt",
+ "https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Japanese_clip_overnorm_over3_noln.ckpt",
+ "https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Spanish_clip_overnorm_over3_noln.ckpt",
+ "https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_sound2img_audioclip_us8k.ckpt"
+]
-from diffusers import DiffusionPipeline
+LANGUAGE_PROMPTS = {
+ "French": "une voiture sur la plage",
+ #"Chinese": "海滩上的一辆车",
+ #"Italian": "una macchina sulla spiaggia",
+ #"Japanese": "浜辺の車",
+ #"Spanish": "un coche en la playa"
+}
-if __name__ == "__main__":
- device = "cuda"
+def download_checkpoints(checkpoint_dir):
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ for url in CHECKPOINTS:
+ filename = os.path.join(checkpoint_dir, os.path.basename(url))
+ if not os.path.exists(filename):
+ print(f"Downloading {filename}...")
+ urllib.request.urlretrieve(url, filename)
+ print(f"Downloaded {filename}")
+ else:
+ print(f"Checkpoint {filename} already exists, skipping download.")
+ return checkpoint_dir
+
+def load_checkpoint(pipeline, checkpoint_path, device):
+ state_dict = torch.load(checkpoint_path, map_location=device)
+ state_dict = state_dict.get("state_dict", state_dict)
+ missing_keys, unexpected_keys = pipeline.unet.load_state_dict(state_dict, strict=False)
+ return pipeline
+
+def generate_image(pipeline, prompt, device, output_path):
+ with torch.inference_mode():
+ image = pipeline(
+ prompt,
+ generator=torch.Generator(device=device).manual_seed(42),
+ num_inference_steps=50
+ ).images[0]
+ image.save(output_path)
+ print(f"Image saved to {output_path}")
+
+checkpoint_dir = download_checkpoints("./checkpoints_all/gluenet_checkpoint")
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
- lm_model_id = "xlm-roberta-large"
- token_max_length = 77
+tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base", use_fast=False)
+model = XLMRobertaForMaskedLM.from_pretrained("xlm-roberta-base").to(device)
+inputs = tokenizer("Ceci est une phrase incomplète avec un [MASK].", return_tensors="pt").to(device)
+with torch.inference_mode():
+ _ = model(**inputs)
- text_encoder = AutoModel.from_pretrained(lm_model_id)
- tokenizer = AutoTokenizer.from_pretrained(lm_model_id, model_max_length=token_max_length, use_fast=False)
- tensor_norm = torch.Tensor([[43.8203],[28.3668],[27.9345],[28.0084],[28.2958],[28.2576],[28.3373],[28.2695],[28.4097],[28.2790],[28.2825],[28.2807],[28.2775],[28.2708],[28.2682],[28.2624],[28.2589],[28.2611],[28.2616],[28.2639],[28.2613],[28.2566],[28.2615],[28.2665],[28.2799],[28.2885],[28.2852],[28.2863],[28.2780],[28.2818],[28.2764],[28.2532],[28.2412],[28.2336],[28.2514],[28.2734],[28.2763],[28.2977],[28.2971],[28.2948],[28.2818],[28.2676],[28.2831],[28.2890],[28.2979],[28.2999],[28.3117],[28.3363],[28.3554],[28.3626],[28.3589],[28.3597],[28.3543],[28.3660],[28.3731],[28.3717],[28.3812],[28.3753],[28.3810],[28.3777],[28.3693],[28.3713],[28.3670],[28.3691],[28.3679],[28.3624],[28.3703],[28.3703],[28.3720],[28.3594],[28.3576],[28.3562],[28.3438],[28.3376],[28.3389],[28.3433],[28.3191]])
+clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
- pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- text_encoder=text_encoder,
- tokenizer=tokenizer,
- custom_pipeline="gluegen"
- ).to(device)
- pipeline.load_language_adapter("gluenet_French_clip_overnorm_over3_noln.ckpt", num_token=token_max_length, dim=1024, dim_out=768, tensor_norm=tensor_norm)
+# Initialize pipeline
+pipeline = DiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ text_encoder=clip_text_encoder,
+ tokenizer=clip_tokenizer,
+ custom_pipeline="gluegen",
+ safety_checker=None
+).to(device)
+
+os.makedirs("outputs", exist_ok=True)
+
+# Generate images
+for language, prompt in LANGUAGE_PROMPTS.items():
- prompt = "une voiture sur la plage"
+ checkpoint_file = f"gluenet_{language}_clip_overnorm_over3_noln.ckpt"
+ checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
+ try:
+ pipeline = load_checkpoint(pipeline, checkpoint_path, device)
+ output_path = f"outputs/gluegen_output_{language.lower()}.png"
+ generate_image(pipeline, prompt, device, output_path)
+ except Exception as e:
+ print(f"Error processing {language} model: {e}")
+ continue
- generator = torch.Generator(device=device).manual_seed(42)
- image = pipeline(prompt, generator=generator).images[0]
- image.save("gluegen_output_fr.png")
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
```
Which will produce:
@@ -982,28 +1270,39 @@ The aim is to overlay two images, then mask out the boundary between `image` and
For example, this could be used to place a logo on a shirt and make it blend seamlessly.
```python
-import PIL
import torch
-
+import requests
+from PIL import Image
+from io import BytesIO
from diffusers import DiffusionPipeline
-image_path = "./path-to-image.png"
-inner_image_path = "./path-to-inner-image.png"
-mask_path = "./path-to-mask.png"
+image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+inner_image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+def load_image(url, mode="RGB"):
+ response = requests.get(url)
+ if response.status_code == 200:
+ return Image.open(BytesIO(response.content)).convert(mode).resize((512, 512))
+ else:
+ raise FileNotFoundError(f"Could not retrieve image from {url}")
+
-init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
-inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
-mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
+init_image = load_image(image_url, mode="RGB")
+inner_image = load_image(inner_image_url, mode="RGBA")
+mask_image = load_image(mask_url, mode="RGB")
pipe = DiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-inpainting",
+ "stable-diffusion-v1-5/stable-diffusion-inpainting",
custom_pipeline="img2img_inpainting",
torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
-prompt = "Your prompt here!"
+prompt = "a mecha robot sitting on a bench"
image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
+
+image.save("output.png")
```

@@ -1016,28 +1315,49 @@ Currently uses the CLIPSeg model for mask generation, then calls the standard St
```python
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import DiffusionPipeline
-
from PIL import Image
import requests
+import torch
+# Load CLIPSeg model and processor
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
-model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to("cuda")
+# Load Stable Diffusion Inpainting Pipeline with custom pipeline
pipe = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
custom_pipeline="text_inpainting",
segmentation_model=model,
segmentation_processor=processor
-)
-pipe = pipe.to("cuda")
-
+).to("cuda")
+# Load input image
url = "https://github.com/timojl/clipseg/blob/master/example_image.jpg?raw=true"
-image = Image.open(requests.get(url, stream=True).raw).resize((512, 512))
-text = "a glass" # will mask out this text
-prompt = "a cup" # the masked out region will be replaced with this
+image = Image.open(requests.get(url, stream=True).raw)
+
+# Step 1: Resize input image for CLIPSeg (224x224)
+segmentation_input = image.resize((224, 224))
-image = pipe(image=image, text=text, prompt=prompt).images[0]
+# Step 2: Generate segmentation mask
+text = "a glass" # Object to mask
+inputs = processor(text=text, images=segmentation_input, return_tensors="pt").to("cuda")
+
+with torch.no_grad():
+ mask = model(**inputs).logits.sigmoid() # Get segmentation mask
+
+# Resize mask back to 512x512 for SD inpainting
+mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(512, 512), mode="bilinear").squeeze(0)
+
+# Step 3: Resize input image for Stable Diffusion
+image = image.resize((512, 512))
+
+# Step 4: Run inpainting with Stable Diffusion
+prompt = "a cup" # The masked-out region will be replaced with this
+result = pipe(image=image, mask=mask, prompt=prompt,text=text).images[0]
+
+# Save output
+result.save("inpainting_output.png")
+print("Inpainting completed. Image saved as 'inpainting_output.png'.")
```
### Bit Diffusion
@@ -1213,8 +1533,10 @@ There are 3 parameters for the method-
Here is an example usage-
```python
+import requests
from diffusers import DiffusionPipeline, DDIMScheduler
from PIL import Image
+from io import BytesIO
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
@@ -1222,9 +1544,11 @@ pipe = DiffusionPipeline.from_pretrained(
scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
).to('cuda')
-img = Image.open('phone.jpg')
+url = "https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg"
+response = requests.get(url)
+image = Image.open(BytesIO(response.content)).convert("RGB") # Convert to RGB to avoid issues
mix_img = pipe(
- img,
+ image,
prompt='bed',
kmin=0.3,
kmax=0.5,
@@ -1377,6 +1701,8 @@ This Diffusion Pipeline takes two images or an image_embeddings tensor of size 2
import torch
from diffusers import DiffusionPipeline
from PIL import Image
+import requests
+from io import BytesIO
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
@@ -1388,13 +1714,25 @@ pipe = DiffusionPipeline.from_pretrained(
)
pipe.to(device)
-images = [Image.open('./starry_night.jpg'), Image.open('./flowers.jpg')]
+# List of image URLs
+image_urls = [
+ 'https://camo.githubusercontent.com/ef13c8059b12947c0d5e8d3ea88900de6bf1cd76bbf61ace3928e824c491290e/68747470733a2f2f68756767696e67666163652e636f2f64617461736574732f4e616761536169416268696e61792f556e434c4950496d616765496e746572706f6c6174696f6e53616d706c65732f7265736f6c76652f6d61696e2f7374617272795f6e696768742e6a7067',
+ 'https://camo.githubusercontent.com/d1947ab7c49ae3f550c28409d5e8b120df48e456559cf4557306c0848337702c/68747470733a2f2f68756767696e67666163652e636f2f64617461736574732f4e616761536169416268696e61792f556e434c4950496d616765496e746572706f6c6174696f6e53616d706c65732f7265736f6c76652f6d61696e2f666c6f776572732e6a7067'
+]
+
+# Open images from URLs
+images = []
+for url in image_urls:
+ response = requests.get(url)
+ img = Image.open(BytesIO(response.content))
+ images.append(img)
+
# For best results keep the prompts close in length to each other. Of course, feel free to try out with differing lengths.
generator = torch.Generator(device=device).manual_seed(42)
output = pipe(image=images, steps=6, generator=generator)
-for i,image in enumerate(output.images):
+for i, image in enumerate(output.images):
image.save('starry_to_flowers_%s.jpg' % i)
```
@@ -1471,37 +1809,51 @@ from diffusers import DiffusionPipeline
from PIL import Image
from transformers import CLIPImageProcessor, CLIPModel
+# Load CLIP model and feature extractor
feature_extractor = CLIPImageProcessor.from_pretrained(
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
)
clip_model = CLIPModel.from_pretrained(
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
)
+
+# Load guided pipeline
guided_pipeline = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
- # custom_pipeline="clip_guided_stable_diffusion",
- custom_pipeline="/home/njindal/diffusers/examples/community/clip_guided_stable_diffusion.py",
+ custom_pipeline="clip_guided_stable_diffusion_img2img",
clip_model=clip_model,
feature_extractor=feature_extractor,
torch_dtype=torch.float16,
)
guided_pipeline.enable_attention_slicing()
guided_pipeline = guided_pipeline.to("cuda")
+
+# Define prompt and fetch image
prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+edit_image = Image.open(BytesIO(response.content)).convert("RGB")
+
+# Run the pipeline
image = guided_pipeline(
prompt=prompt,
- num_inference_steps=30,
- image=init_image,
- strength=0.75,
- guidance_scale=7.5,
- clip_guidance_scale=100,
- num_cutouts=4,
- use_cutouts=False,
+ height=512, # Height of the output image
+ width=512, # Width of the output image
+ image=edit_image, # Input image to guide the diffusion
+ strength=0.75, # How much to transform the input image
+ num_inference_steps=30, # Number of diffusion steps
+ guidance_scale=7.5, # Scale of the classifier-free guidance
+ clip_guidance_scale=100, # Scale of the CLIP guidance
+ num_images_per_prompt=1, # Generate one image per prompt
+ eta=0.0, # Noise scheduling parameter
+ num_cutouts=4, # Number of cutouts for CLIP guidance
+ use_cutouts=False, # Whether to use cutouts
+ output_type="pil", # Output as PIL image
).images[0]
-display(image)
+
+# Display the generated image
+image.show()
+
```
Init Image
@@ -2078,81 +2430,15 @@ CLIP guided stable diffusion images mixing pipeline allows to combine two images
This approach is using (optional) CoCa model to avoid writing image description.
[More code examples](https://github.com/TheDenk/images_mixing)
-### Stable Diffusion XL Long Weighted Prompt Pipeline
-
-This SDXL pipeline supports unlimited length prompt and negative prompt, compatible with A1111 prompt weighted style.
-
-You can provide both `prompt` and `prompt_2`. If only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.
-
-```python
-from diffusers import DiffusionPipeline
-from diffusers.utils import load_image
-import torch
-
-pipe = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0"
- , torch_dtype = torch.float16
- , use_safetensors = True
- , variant = "fp16"
- , custom_pipeline = "lpw_stable_diffusion_xl",
-)
-
-prompt = "photo of a cute (white) cat running on the grass" * 20
-prompt2 = "chasing (birds:1.5)" * 20
-prompt = f"{prompt},{prompt2}"
-neg_prompt = "blur, low quality, carton, animate"
-
-pipe.to("cuda")
-
-# text2img
-t2i_images = pipe(
- prompt=prompt,
- negative_prompt=neg_prompt,
-).images # alternatively, you can call the .text2img() function
-
-# img2img
-input_image = load_image("/path/to/local/image.png") # or URL to your input image
-i2i_images = pipe.img2img(
- prompt=prompt,
- negative_prompt=neg_prompt,
- image=input_image,
- strength=0.8, # higher strength will result in more variation compared to original image
-).images
-
-# inpaint
-input_mask = load_image("/path/to/local/mask.png") # or URL to your input inpainting mask
-inpaint_images = pipe.inpaint(
- prompt="photo of a cute (black) cat running on the grass" * 20,
- negative_prompt=neg_prompt,
- image=input_image,
- mask=input_mask,
- strength=0.6, # higher strength will result in more variation compared to original image
-).images
-
-pipe.to("cpu")
-torch.cuda.empty_cache()
-
-from IPython.display import display # assuming you are using this code in a notebook
-display(t2i_images[0])
-display(i2i_images[0])
-display(inpaint_images[0])
-```
-
-In the above code, the `prompt2` is appended to the `prompt`, which is more than 77 tokens. "birds" are showing up in the result.
-
-
-For more results, checkout [PR #6114](https://github.com/huggingface/diffusers/pull/6114).
-
### Example Images Mixing (with CoCa)
```python
-import requests
-from io import BytesIO
-
import PIL
import torch
+import requests
import open_clip
from open_clip import SimpleTokenizer
+from io import BytesIO
from diffusers import DiffusionPipeline
from transformers import CLIPImageProcessor, CLIPModel
@@ -2215,11 +2501,80 @@ pipe_images = mixing_pipeline(
clip_guidance_scale=100,
generator=generator,
).images
+
+output_path = "mixed_output.jpg"
+pipe_images[0].save(output_path)
+print(f"Image saved successfully at {output_path}")
```

-### Stable Diffusion Mixture Tiling
+### Stable Diffusion XL Long Weighted Prompt Pipeline
+
+This SDXL pipeline supports unlimited length prompt and negative prompt, compatible with A1111 prompt weighted style.
+
+You can provide both `prompt` and `prompt_2`. If only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.
+
+```python
+from diffusers import DiffusionPipeline
+from diffusers.utils import load_image
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0"
+ , torch_dtype = torch.float16
+ , use_safetensors = True
+ , variant = "fp16"
+ , custom_pipeline = "lpw_stable_diffusion_xl",
+)
+
+prompt = "photo of a cute (white) cat running on the grass" * 20
+prompt2 = "chasing (birds:1.5)" * 20
+prompt = f"{prompt},{prompt2}"
+neg_prompt = "blur, low quality, carton, animate"
+
+pipe.to("cuda")
+
+# text2img
+t2i_images = pipe(
+ prompt=prompt,
+ negative_prompt=neg_prompt,
+).images # alternatively, you can call the .text2img() function
+
+# img2img
+input_image = load_image("/path/to/local/image.png") # or URL to your input image
+i2i_images = pipe.img2img(
+ prompt=prompt,
+ negative_prompt=neg_prompt,
+ image=input_image,
+ strength=0.8, # higher strength will result in more variation compared to original image
+).images
+
+# inpaint
+input_mask = load_image("/path/to/local/mask.png") # or URL to your input inpainting mask
+inpaint_images = pipe.inpaint(
+ prompt="photo of a cute (black) cat running on the grass" * 20,
+ negative_prompt=neg_prompt,
+ image=input_image,
+ mask=input_mask,
+ strength=0.6, # higher strength will result in more variation compared to original image
+).images
+
+pipe.to("cpu")
+torch.cuda.empty_cache()
+
+from IPython.display import display # assuming you are using this code in a notebook
+display(t2i_images[0])
+display(i2i_images[0])
+display(inpaint_images[0])
+```
+
+In the above code, the `prompt2` is appended to the `prompt`, which is more than 77 tokens. "birds" are showing up in the result.
+
+
+For more results, checkout [PR #6114](https://github.com/huggingface/diffusers/pull/6114).
+
+### Stable Diffusion Mixture Tiling Pipeline SD 1.5
This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details.
@@ -2246,9 +2601,195 @@ image = pipeline(
seed=7178915308,
num_inference_steps=50,
)["images"][0]
-```
-
-
+```
+
+
+
+### Stable Diffusion Mixture Canvas Pipeline SD 1.5
+
+This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details.
+
+```python
+from PIL import Image
+from diffusers import LMSDiscreteScheduler, DiffusionPipeline
+from diffusers.pipelines.pipeline_utils import Image2ImageRegion, Text2ImageRegion, preprocess_image
+
+
+# Load and preprocess guide image
+iic_image = preprocess_image(Image.open("input_image.png").convert("RGB"))
+
+# Create scheduler and model (similar to StableDiffusionPipeline)
+scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler).to("cuda:0", custom_pipeline="mixture_canvas")
+pipeline.to("cuda")
+
+# Mixture of Diffusers generation
+output = pipeline(
+ canvas_height=800,
+ canvas_width=352,
+ regions=[
+ Text2ImageRegion(0, 800, 0, 352, guidance_scale=8,
+ prompt=f"best quality, masterpiece, WLOP, sakimichan, art contest winner on pixiv, 8K, intricate details, wet effects, rain drops, ethereal, mysterious, futuristic, UHD, HDR, cinematic lighting, in a beautiful forest, rainy day, award winning, trending on artstation, beautiful confident cheerful young woman, wearing a futuristic sleeveless dress, ultra beautiful detailed eyes, hyper-detailed face, complex, perfect, model, textured, chiaroscuro, professional make-up, realistic, figure in frame, "),
+ Image2ImageRegion(352-800, 352, 0, 352, reference_image=iic_image, strength=1.0),
+ ],
+ num_inference_steps=100,
+ seed=5525475061,
+)["images"][0]
+```
+
+
+
+
+### Stable Diffusion Mixture Tiling Pipeline SDXL
+
+This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details.
+
+```python
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, AutoencoderKL
+
+device="cuda"
+
+# Load fixed vae (optional)
+vae = AutoencoderKL.from_pretrained(
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+).to(device)
+
+# Create scheduler and model (similar to StableDiffusionPipeline)
+model_id="stablediffusionapi/yamermix-v8-vae"
+scheduler = DPMSolverMultistepScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+pipe = DiffusionPipeline.from_pretrained(
+ model_id,
+ torch_dtype=torch.float16,
+ vae=vae,
+ custom_pipeline="mixture_tiling_sdxl",
+ scheduler=scheduler,
+ use_safetensors=False
+).to(device)
+
+pipe.enable_model_cpu_offload()
+pipe.enable_vae_tiling()
+pipe.enable_vae_slicing()
+
+generator = torch.Generator(device).manual_seed(297984183)
+
+# Mixture of Diffusers generation
+image = pipe(
+ prompt=[[
+ "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
+ "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
+ "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"
+ ]],
+ tile_height=1024,
+ tile_width=1280,
+ tile_row_overlap=0,
+ tile_col_overlap=256,
+ guidance_scale_tiles=[[7, 7, 7]], # or guidance_scale=7 if is the same for all prompts
+ height=1024,
+ width=3840,
+ generator=generator,
+ num_inference_steps=30,
+)["images"][0]
+```
+
+
+
+### Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL
+
+This pipeline implements the [MoD (Mixture-of-Diffusers)]("https://arxiv.org/pdf/2408.06072") tiled diffusion technique and combines it with SDXL's ControlNet Tile process to generate SR images.
+
+This works better with 4x scales, but you can try adjusts parameters to higher scales.
+
+````python
+import torch
+from diffusers import DiffusionPipeline, ControlNetUnionModel, AutoencoderKL, UniPCMultistepScheduler, UNet2DConditionModel
+from diffusers.utils import load_image
+from PIL import Image
+
+device = "cuda"
+
+# Initialize the models and pipeline
+controlnet = ControlNetUnionModel.from_pretrained(
+ "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
+).to(device=device)
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device=device)
+
+model_id = "SG161222/RealVisXL_V5.0"
+pipe = DiffusionPipeline.from_pretrained(
+ model_id,
+ torch_dtype=torch.float16,
+ vae=vae,
+ controlnet=controlnet,
+ custom_pipeline="mod_controlnet_tile_sr_sdxl",
+ use_safetensors=True,
+ variant="fp16",
+).to(device)
+
+unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
+
+#pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
+pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
+pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
+
+# Set selected scheduler
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+# Load image
+control_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/1.jpg")
+original_height = control_image.height
+original_width = control_image.width
+print(f"Current resolution: H:{original_height} x W:{original_width}")
+
+# Pre-upscale image for tiling
+resolution = 4096
+tile_gaussian_sigma = 0.3
+max_tile_size = 1024 # or 1280
+
+current_size = max(control_image.size)
+scale_factor = max(2, resolution / current_size)
+new_size = (int(control_image.width * scale_factor), int(control_image.height * scale_factor))
+image = control_image.resize(new_size, Image.LANCZOS)
+
+# Update target height and width
+target_height = image.height
+target_width = image.width
+print(f"Target resolution: H:{target_height} x W:{target_width}")
+
+# Calculate overlap size
+normal_tile_overlap, border_tile_overlap = pipe.calculate_overlap(target_width, target_height)
+
+# Set other params
+tile_weighting_method = pipe.TileWeightingMethod.COSINE.value
+guidance_scale = 4
+num_inference_steps = 35
+denoising_strenght = 0.65
+controlnet_strength = 1.0
+prompt = "high-quality, noise-free edges, high quality, 4k, hd, 8k"
+negative_prompt = "blurry, pixelated, noisy, low resolution, artifacts, poor details"
+
+# Image generation
+generated_image = pipe(
+ image=image,
+ control_image=control_image,
+ control_mode=[6],
+ controlnet_conditioning_scale=float(controlnet_strength),
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ normal_tile_overlap=normal_tile_overlap,
+ border_tile_overlap=border_tile_overlap,
+ height=target_height,
+ width=target_width,
+ original_size=(original_width, original_height),
+ target_size=(target_width, target_height),
+ guidance_scale=guidance_scale,
+ strength=float(denoising_strenght),
+ tile_weighting_method=tile_weighting_method,
+ max_tile_size=max_tile_size,
+ tile_gaussian_sigma=float(tile_gaussian_sigma),
+ num_inference_steps=num_inference_steps,
+)["images"][0]
+````
+
### TensorRT Inpainting Stable Diffusion Pipeline
@@ -2292,41 +2833,6 @@ image = pipe(prompt, image=input_image, mask_image=mask_image, strength=0.75,).i
image.save('tensorrt_inpaint_mecha_robot.png')
```
-### Stable Diffusion Mixture Canvas
-
-This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details.
-
-```python
-from PIL import Image
-from diffusers import LMSDiscreteScheduler, DiffusionPipeline
-from diffusers.pipelines.pipeline_utils import Image2ImageRegion, Text2ImageRegion, preprocess_image
-
-
-# Load and preprocess guide image
-iic_image = preprocess_image(Image.open("input_image.png").convert("RGB"))
-
-# Create scheduler and model (similar to StableDiffusionPipeline)
-scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler).to("cuda:0", custom_pipeline="mixture_canvas")
-pipeline.to("cuda")
-
-# Mixture of Diffusers generation
-output = pipeline(
- canvas_height=800,
- canvas_width=352,
- regions=[
- Text2ImageRegion(0, 800, 0, 352, guidance_scale=8,
- prompt=f"best quality, masterpiece, WLOP, sakimichan, art contest winner on pixiv, 8K, intricate details, wet effects, rain drops, ethereal, mysterious, futuristic, UHD, HDR, cinematic lighting, in a beautiful forest, rainy day, award winning, trending on artstation, beautiful confident cheerful young woman, wearing a futuristic sleeveless dress, ultra beautiful detailed eyes, hyper-detailed face, complex, perfect, model, textured, chiaroscuro, professional make-up, realistic, figure in frame, "),
- Image2ImageRegion(352-800, 352, 0, 352, reference_image=iic_image, strength=1.0),
- ],
- num_inference_steps=100,
- seed=5525475061,
-)["images"][0]
-```
-
-
-
-
### IADB pipeline
This pipeline is the implementation of the [α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) paper.
@@ -2460,16 +2966,17 @@ for obj in range(bs):
### Stable Diffusion XL Reference
-This pipeline uses the Reference. Refer to the [stable_diffusion_reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference).
+This pipeline uses the Reference. Refer to the [Stable Diffusion Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference) section for more information.
```py
import torch
-from PIL import Image
+# from diffusers import DiffusionPipeline
from diffusers.utils import load_image
-from diffusers import DiffusionPipeline
from diffusers.schedulers import UniPCMultistepScheduler
-input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+from .stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
+
+input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg")
# pipe = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-base-1.0",
@@ -2487,7 +2994,7 @@ pipe = StableDiffusionXLReferencePipeline.from_pretrained(
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
result_img = pipe(ref_image=input_image,
- prompt="1girl",
+ prompt="a dog",
num_inference_steps=20,
reference_attn=True,
reference_adain=True).images[0]
@@ -2495,14 +3002,14 @@ result_img = pipe(ref_image=input_image,
Reference Image
-
+
Output Image
-`prompt: 1 girl`
+`prompt: a dog`
-`reference_attn=True, reference_adain=True, num_inference_steps=20`
-
+`reference_attn=False, reference_adain=True, num_inference_steps=20`
+
Reference Image

@@ -2524,6 +3031,88 @@ Output Image
`reference_attn=True, reference_adain=True, num_inference_steps=20`

+### Stable Diffusion XL ControlNet Reference
+
+This pipeline uses the Reference Control and with ControlNet. Refer to the [Stable Diffusion ControlNet Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-controlnet-reference) and [Stable Diffusion XL Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-xl-reference) sections for more information.
+
+```py
+from diffusers import ControlNetModel, AutoencoderKL
+from diffusers.schedulers import UniPCMultistepScheduler
+from diffusers.utils import load_image
+import numpy as np
+import torch
+
+import cv2
+from PIL import Image
+
+from .stable_diffusion_xl_controlnet_reference import StableDiffusionXLControlNetReferencePipeline
+
+# download an image
+canny_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg"
+)
+
+ref_image = load_image(
+ "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+)
+
+# initialize the models and pipeline
+controlnet_conditioning_scale = 0.5 # recommended for good generalization
+controlnet = ControlNetModel.from_pretrained(
+ "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
+)
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipe = StableDiffusionXLControlNetReferencePipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
+).to("cuda:0")
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+# get canny image
+image = np.array(canny_image)
+image = cv2.Canny(image, 100, 200)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image)
+
+# generate image
+image = pipe(
+ prompt="a cat",
+ num_inference_steps=20,
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
+ image=canny_image,
+ ref_image=ref_image,
+ reference_attn=False,
+ reference_adain=True,
+ style_fidelity=1.0,
+ generator=torch.Generator("cuda").manual_seed(42)
+).images[0]
+```
+
+Canny ControlNet Image
+
+
+
+Reference Image
+
+
+
+Output Image
+
+`prompt: a cat`
+
+`reference_attn=True, reference_adain=True, num_inference_steps=20, style_fidelity=1.0`
+
+
+
+`reference_attn=False, reference_adain=True, num_inference_steps=20, style_fidelity=1.0`
+
+
+
+`reference_attn=True, reference_adain=False, num_inference_steps=20, style_fidelity=1.0`
+
+
+
### Stable diffusion fabric pipeline
FABRIC approach applicable to a wide range of popular diffusion models, which exploits
@@ -2675,14 +3264,19 @@ Here's a full example for `ReplaceEdit``:
```python
import torch
-import numpy as np
-import matplotlib.pyplot as plt
from diffusers import DiffusionPipeline
+import numpy as np
+from PIL import Image
-pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="pipeline_prompt2prompt").to("cuda")
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="pipeline_prompt2prompt"
+).to("cuda")
-prompts = ["A turtle playing with a ball",
- "A monkey playing with a ball"]
+prompts = [
+ "A turtle playing with a ball",
+ "A monkey playing with a ball"
+]
cross_attention_kwargs = {
"edit_type": "replace",
@@ -2690,7 +3284,15 @@ cross_attention_kwargs = {
"self_replace_steps": 0.4
}
-outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=50, cross_attention_kwargs=cross_attention_kwargs)
+outputs = pipe(
+ prompt=prompts,
+ height=512,
+ width=512,
+ num_inference_steps=50,
+ cross_attention_kwargs=cross_attention_kwargs
+)
+
+outputs.images[0].save("output_image_0.png")
```
And abbreviated examples for the other edits:
@@ -3219,6 +3821,20 @@ best quality, 3persons in garden, a boy blue shirt BREAK
best quality, 3persons in garden, an old man red suit
```
+### Use base prompt
+
+You can use a base prompt to apply the prompt to all areas. You can set a base prompt by adding `ADDBASE` at the end. Base prompts can also be combined with common prompts, but the base prompt must be specified first.
+
+```
+2d animation style ADDBASE
+masterpiece, high quality ADDCOMM
+(blue sky)++ BREAK
+green hair twintail BREAK
+book shelf BREAK
+messy desk BREAK
+orange++ dress and sofa
+```
+
### Negative prompt
Negative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs must be the same as the number of prompts. If the number of prompts does not match, the negative prompts will be used without being divided into regions.
@@ -3249,6 +3865,7 @@ pipe(prompt=prompt, rp_args=rp_args)
### Optional Parameters
- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
+- `base_ratio`: Used with `ADDBASE`. Sets the ratio of the base prompt; if base ratio is set to 0.2, then resulting images will consist of `20%*BASE_PROMPT + 80%*REGION_PROMPT`
The Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.
@@ -3577,6 +4194,7 @@ The original repo can be found at [repo](https://github.com/PRIS-CV/DemoFusion).
```py
from diffusers import DiffusionPipeline
+import torch
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
@@ -3644,33 +4262,89 @@ This pipeline provides drag-and-drop image editing using stochastic differential
See [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), [original repo](https://github.com/ML-GSAI/SDE-Drag) for more information.
```py
-import PIL
import torch
from diffusers import DDIMScheduler, DiffusionPipeline
+from PIL import Image
+import requests
+from io import BytesIO
+import numpy as np
# Load the pipeline
model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
-pipe.to('cuda')
-# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
-# If not training LoRA, please avoid using torch.float16
-# pipe.to(torch.float16)
+# Ensure the model is moved to the GPU
+device = "cuda" if torch.cuda.is_available() else "cpu"
+pipe.to(device)
+
+# Function to load image from URL
+def load_image_from_url(url):
+ response = requests.get(url)
+ return Image.open(BytesIO(response.content)).convert("RGB")
+
+# Function to prepare mask
+def prepare_mask(mask_image):
+ # Convert to grayscale
+ mask = mask_image.convert("L")
+ return mask
+
+# Function to convert numpy array to PIL Image
+def array_to_pil(array):
+ # Ensure the array is in uint8 format
+ if array.dtype != np.uint8:
+ if array.max() <= 1.0:
+ array = (array * 255).astype(np.uint8)
+ else:
+ array = array.astype(np.uint8)
+
+ # Handle different array shapes
+ if len(array.shape) == 3:
+ if array.shape[0] == 3: # If channels first
+ array = array.transpose(1, 2, 0)
+ return Image.fromarray(array)
+ elif len(array.shape) == 4: # If batch dimension
+ array = array[0]
+ if array.shape[0] == 3: # If channels first
+ array = array.transpose(1, 2, 0)
+ return Image.fromarray(array)
+ else:
+ raise ValueError(f"Unexpected array shape: {array.shape}")
+
+# Image and mask URLs
+image_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png'
+mask_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png'
+
+# Load the images
+image = load_image_from_url(image_url)
+mask_image = load_image_from_url(mask_url)
-# Provide prompt, image, mask image, and the starting and target points for drag editing.
-prompt = "prompt of the image"
-image = PIL.Image.open('/path/to/image')
-mask_image = PIL.Image.open('/path/to/mask_image')
-source_points = [[123, 456]]
-target_points = [[234, 567]]
+# Resize images to a size that's compatible with the model's latent space
+image = image.resize((512, 512))
+mask_image = mask_image.resize((512, 512))
-# train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image.
-pipe.train_lora(prompt, image)
+# Prepare the mask (keep as PIL Image)
+mask = prepare_mask(mask_image)
-output = pipe(prompt, image, mask_image, source_points, target_points)
-output_image = PIL.Image.fromarray(output)
+# Provide the prompt and points for drag editing
+prompt = "A cute dog"
+source_points = [[32, 32]] # Adjusted for 512x512 image
+target_points = [[64, 64]] # Adjusted for 512x512 image
+
+# Generate the output image
+output_array = pipe(
+ prompt=prompt,
+ image=image,
+ mask_image=mask,
+ source_points=source_points,
+ target_points=target_points
+)
+
+# Convert output array to PIL Image and save
+output_image = array_to_pil(output_array)
output_image.save("./output.png")
+print("Output image saved as './output.png'")
+
```
### Instaflow Pipeline
@@ -3700,9 +4374,10 @@ You can also combine it with LORA out of the box, like
+
+**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
+
+#### Key features
+
+- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
+- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
+- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
+
+#### Usage example
+To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
+```py
+import torch
+from diffusers import DDIMScheduler, DiffusionPipeline
+from diffusers.utils import load_image
+import torch.nn.functional as F
+from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+dtype = torch.float16
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+ scheduler=scheduler,
+ variant="fp16",
+ use_safetensors=True,
+ torch_dtype=dtype,
+).to(device)
+
+
+def preprocess_image(image_path, device):
+ image = to_tensor((load_image(image_path)))
+ image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
+ if image.shape[1] != 3:
+ image = image.expand(-1, 3, -1, -1)
+ image = F.interpolate(image, (1024, 1024))
+ image = image.to(dtype).to(device)
+ return image
+
+def preprocess_mask(mask_path, device):
+ mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
+ mask = mask.unsqueeze_(0).float() # 0 or 1
+ mask = F.interpolate(mask, (1024, 1024))
+ mask = gaussian_blur(mask, kernel_size=(77, 77))
+ mask[mask < 0.1] = 0
+ mask[mask >= 0.1] = 1
+ mask = mask.to(dtype).to(device)
+ return mask
+
+prompt = "" # Set prompt to null
+seed=123
+generator = torch.Generator(device=device).manual_seed(seed)
+source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
+mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
+source_image = preprocess_image(source_image_path, device)
+mask = preprocess_mask(mask_path, device)
+
+image = pipeline(
+ prompt=prompt,
+ image=source_image,
+ mask_image=mask,
+ height=1024,
+ width=1024,
+ AAS=True, # enable AAS
+ strength=0.8, # inpainting strength
+ rm_guidance_scale=9, # removal guidance scale
+ ss_steps = 9, # similarity suppression steps
+ ss_scale = 0.3, # similarity suppression scale
+ AAS_start_step=0, # AAS start step
+ AAS_start_layer=34, # AAS start layer
+ AAS_end_layer=70, # AAS end layer
+ num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
+ generator=generator,
+ guidance_scale=1,
+).images[0]
+image.save('./removed_img.png')
+print("Object removal completed")
+```
+
+| Source Image | Mask | Output |
+| ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
+|  |  |  |
+
# Perturbed-Attention Guidance
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
@@ -4445,3 +5207,230 @@ grid_image.save(grid_dir + "sample.png")
`pag_scale` : guidance scale of PAG (ex: 5.0)
`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])
+
+# PIXART-α Controlnet pipeline
+
+[Project](https://pixart-alpha.github.io/) / [GitHub](https://github.com/PixArt-alpha/PixArt-alpha/blob/master/asset/docs/pixart_controlnet.md)
+
+This the implementation of the controlnet model and the pipelne for the Pixart-alpha model, adapted to use the HuggingFace Diffusers.
+
+## Example Usage
+
+This example uses the Pixart HED Controlnet model, converted from the control net model as trained by the authors of the paper.
+
+```py
+import sys
+import os
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+
+from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
+from diffusers.utils import load_image
+
+from diffusers.image_processor import PixArtImageProcessor
+
+from controlnet_aux import HEDdetector
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel
+
+controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"
+
+weight_dtype = torch.float16
+image_size = 1024
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+torch.manual_seed(0)
+
+# load controlnet
+controlnet = PixArtControlNetAdapterModel.from_pretrained(
+ controlnet_repo_id,
+ torch_dtype=weight_dtype,
+ use_safetensors=True,
+).to(device)
+
+pipe = PixArtAlphaControlnetPipeline.from_pretrained(
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
+ controlnet=controlnet,
+ torch_dtype=weight_dtype,
+ use_safetensors=True,
+).to(device)
+
+images_path = "images"
+control_image_file = "0_7.jpg"
+
+prompt = "battleship in space, galaxy in background"
+
+control_image_name = control_image_file.split('.')[0]
+
+control_image = load_image(f"{images_path}/{control_image_file}")
+print(control_image.size)
+height, width = control_image.size
+
+hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
+
+condition_transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB')),
+ T.CenterCrop([image_size, image_size]),
+])
+
+control_image = condition_transform(control_image)
+hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)
+
+hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")
+
+# run pipeline
+with torch.no_grad():
+ out = pipe(
+ prompt=prompt,
+ image=hed_edge,
+ num_inference_steps=14,
+ guidance_scale=4.5,
+ height=image_size,
+ width=image_size,
+ )
+
+ out.images[0].save(f"{images_path}//{control_image_name}_output.jpg")
+
+```
+
+In the folder examples/pixart there is also a script that can be used to train new models.
+Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
+
+# CogVideoX DDIM Inversion Pipeline
+
+This implementation performs DDIM inversion on the video based on CogVideoX and uses guided attention to reconstruct or edit the inversion latents.
+
+## Example Usage
+
+```python
+import torch
+
+from examples.community.cogvideox_ddim_inversion import CogVideoXPipelineForDDIMInversion
+
+
+# Load pretrained pipeline
+pipeline = CogVideoXPipelineForDDIMInversion.from_pretrained(
+ "THUDM/CogVideoX1.5-5B",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Run DDIM inversion, and the videos will be generated in the output_path
+output = pipeline_for_inversion(
+ prompt="prompt that describes the edited video",
+ video_path="path/to/input.mp4",
+ guidance_scale=6.0,
+ num_inference_steps=50,
+ skip_frames_start=0,
+ skip_frames_end=0,
+ frame_sample_step=None,
+ max_num_frames=81,
+ width=720,
+ height=480,
+ seed=42,
+)
+pipeline.export_latents_to_video(output.inverse_latents[-1], "path/to/inverse_video.mp4", fps=8)
+pipeline.export_latents_to_video(output.recon_latents[-1], "path/to/recon_video.mp4", fps=8)
+```
+# FaithDiff Stable Diffusion XL Pipeline
+
+[Project](https://jychen9811.github.io/FaithDiff_page/) / [GitHub](https://github.com/JyChen9811/FaithDiff/)
+
+This the implementation of the FaithDiff pipeline for SDXL, adapted to use the HuggingFace Diffusers.
+
+For more details see the project links above.
+
+## Example Usage
+
+This example upscale and restores a low-quality image. The input image has a resolution of 512x512 and will be upscaled at a scale of 2x, to a final resolution of 1024x1024. It is possible to upscale to a larger scale, but it is recommended that the input image be at least 1024x1024 in these cases. To upscale this image by 4x, for example, it would be recommended to re-input the result into a new 2x processing, thus performing progressive scaling.
+
+````py
+import random
+import numpy as np
+import torch
+from diffusers import DiffusionPipeline, AutoencoderKL, UniPCMultistepScheduler
+from huggingface_hub import hf_hub_download
+from diffusers.utils import load_image
+from PIL import Image
+
+device = "cuda"
+dtype = torch.float16
+MAX_SEED = np.iinfo(np.int32).max
+
+# Download weights for additional unet layers
+model_file = hf_hub_download(
+ "jychen9811/FaithDiff",
+ filename="FaithDiff.bin", local_dir="./proc_data/faithdiff", local_dir_use_symlinks=False
+)
+
+# Initialize the models and pipeline
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
+
+model_id = "SG161222/RealVisXL_V4.0"
+pipe = DiffusionPipeline.from_pretrained(
+ model_id,
+ torch_dtype=dtype,
+ vae=vae,
+ unet=None, #<- Do not load with original model.
+ custom_pipeline="pipeline_faithdiff_stable_diffusion_xl",
+ use_safetensors=True,
+ variant="fp16",
+).to(device)
+
+# Here we need use pipeline internal unet model
+pipe.unet = pipe.unet_model.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
+
+# Load aditional layers to the model
+pipe.unet.load_additional_layers(weight_path="proc_data/faithdiff/FaithDiff.bin", dtype=dtype)
+
+# Enable vae tiling
+pipe.set_encoder_tile_settings()
+pipe.enable_vae_tiling()
+
+# Optimization
+pipe.enable_model_cpu_offload()
+
+# Set selected scheduler
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+#input params
+prompt = "The image features a woman in her 55s with blonde hair and a white shirt, smiling at the camera. She appears to be in a good mood and is wearing a white scarf around her neck. "
+upscale = 2 # scale here
+start_point = "lr" # or "noise"
+latent_tiled_overlap = 0.5
+latent_tiled_size = 1024
+
+# Load image
+lq_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/woman.png")
+original_height = lq_image.height
+original_width = lq_image.width
+print(f"Current resolution: H:{original_height} x W:{original_width}")
+
+width = original_width * int(upscale)
+height = original_height * int(upscale)
+print(f"Final resolution: H:{height} x W:{width}")
+
+# Restoration
+image = lq_image.resize((width, height), Image.LANCZOS)
+input_image, width_init, height_init, width_now, height_now = pipe.check_image_size(image)
+
+generator = torch.Generator(device=device).manual_seed(random.randint(0, MAX_SEED))
+gen_image = pipe(lr_img=input_image,
+ prompt = prompt,
+ num_inference_steps=20,
+ guidance_scale=5,
+ generator=generator,
+ start_point=start_point,
+ height = height_now,
+ width=width_now,
+ overlap=latent_tiled_overlap,
+ target_size=(latent_tiled_size, latent_tiled_size)
+ ).images[0]
+
+cropped_image = gen_image.crop((0, 0, width_init, height_init))
+cropped_image.save("data/result.png")
+````
+### Result
+[ ](https://imgsli.com/MzY1NzE2)
\ No newline at end of file
diff --git a/examples/community/README_community_scripts.md b/examples/community/README_community_scripts.md
index 2c2f549a2bd5..3c9ad0d89bb4 100644
--- a/examples/community/README_community_scripts.md
+++ b/examples/community/README_community_scripts.md
@@ -6,9 +6,9 @@ If a community script doesn't work as expected, please open an issue and ping th
| Example | Description | Code Example | Colab | Author |
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
-| Using IP-Adapter with negative noise | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details) | [IP-Adapter Negative Noise](#ip-adapter-negative-noise) | | [Álvaro Somoza](https://github.com/asomoza)|
-| asymmetric tiling |configure seamless image tiling independently for the X and Y axes | [Asymmetric Tiling](#asymmetric-tiling ) | | [alexisrolland](https://github.com/alexisrolland)|
-| Prompt scheduling callback |Allows changing prompts during a generation | [Prompt Scheduling](#prompt-scheduling ) | | [hlky](https://github.com/hlky)|
+| Using IP-Adapter with Negative Noise | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details) | [IP-Adapter Negative Noise](#ip-adapter-negative-noise) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ip_adapter_negative_noise.ipynb) | [Álvaro Somoza](https://github.com/asomoza)|
+| Asymmetric Tiling |configure seamless image tiling independently for the X and Y axes | [Asymmetric Tiling](#Asymmetric-Tiling ) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/asymetric_tiling.ipynb) | [alexisrolland](https://github.com/alexisrolland)|
+| Prompt Scheduling Callback |Allows changing prompts during a generation | [Prompt Scheduling-Callback](#Prompt-Scheduling-Callback ) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/prompt_scheduling_callback.ipynb) | [hlky](https://github.com/hlky)|
## Example usages
@@ -241,7 +241,45 @@ from diffusers import StableDiffusionPipeline
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
from diffusers.configuration_utils import register_to_config
import torch
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Tuple, Union
+
+
+class SDPromptSchedulingCallback(PipelineCallback):
+ @register_to_config
+ def __init__(
+ self,
+ encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+ cutoff_step_ratio=None,
+ cutoff_step_index=None,
+ ):
+ super().__init__(
+ cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
+ )
+
+ tensor_inputs = ["prompt_embeds"]
+
+ def callback_fn(
+ self, pipeline, step_index, timestep, callback_kwargs
+ ) -> Dict[str, Any]:
+ cutoff_step_ratio = self.config.cutoff_step_ratio
+ cutoff_step_index = self.config.cutoff_step_index
+ if isinstance(self.config.encoded_prompt, tuple):
+ prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
+ else:
+ prompt_embeds = self.config.encoded_prompt
+
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+ cutoff_step = (
+ cutoff_step_index
+ if cutoff_step_index is not None
+ else int(pipeline.num_timesteps * cutoff_step_ratio)
+ )
+
+ if step_index == cutoff_step:
+ if pipeline.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+ return callback_kwargs
pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
@@ -253,28 +291,73 @@ pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
pipeline.safety_checker = None
pipeline.requires_safety_checker = False
+callback = MultiPipelineCallbacks(
+ [
+ SDPromptSchedulingCallback(
+ encoded_prompt=pipeline.encode_prompt(
+ prompt=f"prompt {index}",
+ negative_prompt=f"negative prompt {index}",
+ device=pipeline._execution_device,
+ num_images_per_prompt=1,
+ # pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran
+ do_classifier_free_guidance=True,
+ ),
+ cutoff_step_index=index,
+ ) for index in range(1, 20)
+ ]
+)
-class SDPromptScheduleCallback(PipelineCallback):
+image = pipeline(
+ prompt="prompt"
+ negative_prompt="negative prompt",
+ callback_on_step_end=callback,
+ callback_on_step_end_tensor_inputs=["prompt_embeds"],
+).images[0]
+torch.cuda.empty_cache()
+image.save('image.png')
+```
+
+```python
+from diffusers import StableDiffusionXLPipeline
+from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
+from diffusers.configuration_utils import register_to_config
+import torch
+from typing import Any, Dict, Tuple, Union
+
+
+class SDXLPromptSchedulingCallback(PipelineCallback):
@register_to_config
def __init__(
self,
- prompt: str,
- negative_prompt: Optional[str] = None,
- num_images_per_prompt: int = 1,
- cutoff_step_ratio=1.0,
+ encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+ add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+ add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+ cutoff_step_ratio=None,
cutoff_step_index=None,
):
super().__init__(
cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
)
- tensor_inputs = ["prompt_embeds"]
+ tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
def callback_fn(
self, pipeline, step_index, timestep, callback_kwargs
) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
+ if isinstance(self.config.encoded_prompt, tuple):
+ prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
+ else:
+ prompt_embeds = self.config.encoded_prompt
+ if isinstance(self.config.add_text_embeds, tuple):
+ add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds
+ else:
+ add_text_embeds = self.config.add_text_embeds
+ if isinstance(self.config.add_time_ids, tuple):
+ add_time_ids, negative_add_time_ids = self.config.add_time_ids
+ else:
+ add_time_ids = self.config.add_time_ids
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
@@ -284,32 +367,73 @@ class SDPromptScheduleCallback(PipelineCallback):
)
if step_index == cutoff_step:
- prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
- prompt=self.config.prompt,
- negative_prompt=self.config.negative_prompt,
- device=pipeline._execution_device,
- num_images_per_prompt=self.config.num_images_per_prompt,
- do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
- )
if pipeline.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds])
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids])
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+ callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+ callback_kwargs[self.tensor_inputs[2]] = add_time_ids
return callback_kwargs
-callback = MultiPipelineCallbacks(
- [
- SDPromptScheduleCallback(
- prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
- negative_prompt="Deformed, ugly, bad anatomy",
- cutoff_step_ratio=0.25,
+
+pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ use_safetensors=True,
+).to("cuda")
+
+callbacks = []
+for index in range(1, 20):
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = pipeline.encode_prompt(
+ prompt=f"prompt {index}",
+ negative_prompt=f"prompt {index}",
+ device=pipeline._execution_device,
+ num_images_per_prompt=1,
+ # pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran
+ do_classifier_free_guidance=True,
+ )
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ add_time_ids = pipeline._get_add_time_ids(
+ (1024, 1024),
+ (0, 0),
+ (1024, 1024),
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ negative_add_time_ids = pipeline._get_add_time_ids(
+ (1024, 1024),
+ (0, 0),
+ (1024, 1024),
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ callbacks.append(
+ SDXLPromptSchedulingCallback(
+ encoded_prompt=(prompt_embeds, negative_prompt_embeds),
+ add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds),
+ add_time_ids=(add_time_ids, negative_add_time_ids),
+ cutoff_step_index=index,
)
- ]
-)
+ )
+
+
+callback = MultiPipelineCallbacks(callbacks)
image = pipeline(
- prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
- negative_prompt="Deformed, ugly, bad anatomy",
+ prompt="prompt",
+ negative_prompt="negative prompt",
callback_on_step_end=callback,
- callback_on_step_end_tensor_inputs=["prompt_embeds"],
+ callback_on_step_end_tensor_inputs=[
+ "prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ ],
).images[0]
```
diff --git a/examples/community/adaptive_mask_inpainting.py b/examples/community/adaptive_mask_inpainting.py
new file mode 100644
index 000000000000..df736956485b
--- /dev/null
+++ b/examples/community/adaptive_mask_inpainting.py
@@ -0,0 +1,1469 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/
+
+import inspect
+import os
+import shutil
+from glob import glob
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import cv2
+import numpy as np
+import PIL.Image
+import requests
+import torch
+from detectron2.config import get_cfg
+from detectron2.data import MetadataCatalog
+from detectron2.engine import DefaultPredictor
+from detectron2.projects import point_rend
+from detectron2.structures.instances import Instances
+from detectron2.utils.visualizer import ColorMode, Visualizer
+from packaging import version
+from tqdm import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+AMI_INSTALL_MESSAGE = """
+
+Example Demo of Adaptive Mask Inpainting
+
+Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models
+Kim et al.
+ECCV-2024 (Oral)
+
+
+Please prepare the environment via
+
+```
+conda create --name ami python=3.9 -y
+conda activate ami
+
+conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge -y
+python -m pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
+pip install easydict
+pip install diffusers==0.20.2 accelerate safetensors transformers
+pip install setuptools==59.5.0
+pip install opencv-python
+pip install numpy==1.24.1
+```
+
+
+Put the code inside the root of diffusers library (e.g., as '/home/username/diffusers/adaptive_mask_inpainting_example.py') and run the python code.
+
+
+
+
+"""
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> init_image = load_image(
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
+ ... )
+ >>> init_image = init_image.resize((512, 512))
+
+ >>> generator = torch.Generator(device="cpu").manual_seed(1)
+
+ >>> mask_image = load_image(
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
+ ... )
+ >>> mask_image = mask_image.resize((512, 512))
+
+
+ >>> def make_inpaint_condition(image, image_mask):
+ ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+ ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
+
+ ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
+ ... image[image_mask > 0.5] = -1.0 # set as masked pixel
+ ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+ ... image = torch.from_numpy(image)
+ ... return image
+
+
+ >>> control_image = make_inpaint_condition(init_image, mask_image)
+
+ >>> controlnet = ControlNetModel.from_pretrained(
+ ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
+ ... )
+ >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # generate image
+ >>> image = pipe(
+ ... "a handsome man with ray-ban sunglasses",
+ ... num_inference_steps=20,
+ ... generator=generator,
+ ... eta=1.0,
+ ... image=init_image,
+ ... mask_image=mask_image,
+ ... control_image=control_image,
+ ... ).images[0]
+ ```
+"""
+
+
+def download_file(url, output_file, exist_ok: bool):
+ if exist_ok and os.path.exists(output_file):
+ return
+
+ response = requests.get(url, stream=True)
+
+ with open(output_file, "wb") as file:
+ for chunk in tqdm(response.iter_content(chunk_size=8192), desc=f"Downloading '{output_file}'..."):
+ if chunk:
+ file.write(chunk)
+
+
+def generate_video_from_imgs(images_save_directory, fps=15.0, delete_dir=True):
+ # delete videos if exists
+ if os.path.exists(f"{images_save_directory}.mp4"):
+ os.remove(f"{images_save_directory}.mp4")
+ if os.path.exists(f"{images_save_directory}_before_process.mp4"):
+ os.remove(f"{images_save_directory}_before_process.mp4")
+
+ # assume there are "enumerated" images under "images_save_directory"
+ assert os.path.isdir(images_save_directory)
+ ImgPaths = sorted(glob(f"{images_save_directory}/*"))
+
+ if len(ImgPaths) == 0:
+ print("\tSkipping, since there must be at least one image to create mp4\n")
+ else:
+ # mp4 configuration
+ video_path = images_save_directory + "_before_process.mp4"
+
+ # Get height and width config
+ images = sorted([ImgPath.split("/")[-1] for ImgPath in ImgPaths if ImgPath.endswith(".png")])
+ frame = cv2.imread(os.path.join(images_save_directory, images[0]))
+ height, width, channels = frame.shape
+
+ # create mp4 video writer
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ video = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
+ for image in images:
+ video.write(cv2.imread(os.path.join(images_save_directory, image)))
+ cv2.destroyAllWindows()
+ video.release()
+
+ # generated video is not compatible with HTML5. Post-process and change codec of video, so that it is applicable to HTML.
+ os.system(
+ f'ffmpeg -i "{images_save_directory}_before_process.mp4" -vcodec libx264 -f mp4 "{images_save_directory}.mp4" '
+ )
+
+ # remove group of images, and remove video before post-process.
+ if delete_dir and os.path.exists(images_save_directory):
+ shutil.rmtree(images_save_directory)
+ # remove 'before-process' video
+ if os.path.exists(f"{images_save_directory}_before_process.mp4"):
+ os.remove(f"{images_save_directory}_before_process.mp4")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
+def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+ (ot the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
+
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t passed height an width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
+
+class AdaptiveMaskInpaintPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+
+ Args:
+ vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: Union[AutoencoderKL, AsymmetricAutoencoderKL],
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ # safety_checker: StableDiffusionSafetyChecker,
+ safety_checker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ self.register_adaptive_mask_model()
+ self.register_adaptive_mask_settings()
+
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+ " Hub, it would be very nice if you could open a Pull request for the"
+ " `scheduler/scheduler_config.json` file"
+ )
+ deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["skip_prk_steps"] = True
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
+ if unet is not None and unet.config.in_channels != 9:
+ logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ """ Preparation for Adaptive Mask inpainting """
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a
+ time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs.
+ Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the
+ iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ else:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ default_mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 1.0,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ use_adaptive_mask: bool = True,
+ enforce_full_mask_ratio: float = 0.5,
+ human_detection_thres: float = 0.008,
+ visualization_save_dir: str = None,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ image (`PIL.Image.Image`):
+ `Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked
+ out with `default_mask_image` and repainted according to `prompt`).
+ default_mask_image (`PIL.Image.Image`):
+ `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted
+ while black pixels are preserved. If `default_mask_image` is a PIL image, it is converted to a single channel
+ (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the
+ expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Examples:
+
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> import torch
+ >>> from io import BytesIO
+
+ >>> from diffusers import AdaptiveMaskInpaintPipeline
+
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = download_image(img_url).resize((512, 512))
+ >>> default_mask_image = download_image(mask_url).resize((512, 512))
+
+ >>> pipe = AdaptiveMaskInpaintPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> image = pipe(prompt=prompt, image=init_image, default_mask_image=default_mask_image).images[0]
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # 0. Default height and width to unet
+ width, height = image.size
+ # height = height or self.unet.config.sample_size * self.vae_scale_factor
+ # width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps=num_inference_steps, strength=strength, device=device
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image (will be used later, once again)
+ mask, masked_image, init_image = prepare_mask_and_masked_image(
+ image, default_mask_image, height, width, return_image=True
+ )
+ default_mask_image_np = np.array(default_mask_image).astype(np.uint8) / 255
+ mask_condition = mask.clone()
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `default_mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Denoising loop
+ mask_image_np = default_mask_image_np
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+ else:
+ raise NotImplementedError
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1 & predicted original sample x_0
+ outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
+ latents = outputs["prev_sample"] # x_t-1
+ pred_orig_latents = outputs["pred_original_sample"] # x_0
+
+ # run segmentation
+ if use_adaptive_mask:
+ if enforce_full_mask_ratio > 0.0:
+ use_default_mask = t < self.scheduler.config.num_train_timesteps * enforce_full_mask_ratio
+ elif enforce_full_mask_ratio == 0.0:
+ use_default_mask = False
+ else:
+ raise NotImplementedError
+
+ pred_orig_image = self.decode_to_npuint8_image(pred_orig_latents)
+ dilate_num = self.adaptive_mask_settings.dilate_scheduler(i)
+ do_adapt_mask = self.adaptive_mask_settings.provoke_scheduler(i)
+ if do_adapt_mask:
+ mask, masked_image_latents, mask_image_np, vis_np = self.adapt_mask(
+ init_image,
+ pred_orig_image,
+ default_mask_image_np,
+ dilate_num=dilate_num,
+ use_default_mask=use_default_mask,
+ height=height,
+ width=width,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ device=device,
+ generator=generator,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ i=i,
+ human_detection_thres=human_detection_thres,
+ mask_image_np=mask_image_np,
+ )
+
+ if self.adaptive_mask_model.use_visualizer:
+ import matplotlib.pyplot as plt
+
+ # mask_image_new_colormap = np.clip(0.6 + (1.0 - mask_image_np), a_min=0.0, a_max=1.0) * 255
+
+ os.makedirs(visualization_save_dir, exist_ok=True)
+
+ # Image.fromarray(mask_image_new_colormap).convert("L").save(f"{visualization_save_dir}/masks/{i:05}.png")
+ plt.axis("off")
+ plt.subplot(1, 2, 1)
+ plt.imshow(mask_image_np)
+ plt.subplot(1, 2, 2)
+ plt.imshow(pred_orig_image)
+ plt.savefig(f"{visualization_save_dir}/{i:05}.png", bbox_inches="tight")
+ plt.close("all")
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents[:1]
+ init_mask = mask[:1]
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ condition_kwargs = {}
+ if isinstance(self.vae, AsymmetricAutoencoderKL):
+ init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
+ init_image_condition = init_image.clone()
+ init_image = self._encode_vae_image(init_image, generator=generator)
+ mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
+ condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if self.adaptive_mask_model.use_visualizer:
+ generate_video_from_imgs(images_save_directory=visualization_save_dir, fps=10, delete_dir=True)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def decode_to_npuint8_image(self, latents):
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **{})[
+ 0
+ ] # torch, float32, -1.~1.
+ image = self.image_processor.postprocess(image, output_type="pt", do_denormalize=[True] * image.shape[0])
+ image = (image.squeeze().permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8) # np, uint8, 0~255
+ return image
+
+ def register_adaptive_mask_settings(self):
+ from easydict import EasyDict
+
+ num_steps = 50
+
+ step_num = int(num_steps * 0.1)
+ final_step_num = num_steps - step_num * 7
+ # adaptive mask settings
+ self.adaptive_mask_settings = EasyDict(
+ dilate_scheduler=MaskDilateScheduler(
+ max_dilate_num=20,
+ num_inference_steps=num_steps,
+ schedule=[20] * step_num
+ + [10] * step_num
+ + [5] * step_num
+ + [4] * step_num
+ + [3] * step_num
+ + [2] * step_num
+ + [1] * step_num
+ + [0] * final_step_num,
+ ),
+ dilate_kernel=np.ones((3, 3), dtype=np.uint8),
+ provoke_scheduler=ProvokeScheduler(
+ num_inference_steps=num_steps,
+ schedule=list(range(2, 10 + 1, 2)) + list(range(12, 40 + 1, 2)) + [45],
+ is_zero_indexing=False,
+ ),
+ )
+
+ def register_adaptive_mask_model(self):
+ # declare segmentation model used for mask adaptation
+ use_visualizer = True
+ # assert not use_visualizer, \
+ # """
+ # If you plan to 'use_visualizer', USE WITH CAUTION.
+ # It creates a directory of images and masks, which is used for merging into a video.
+ # The procedure involves deleting the directory of images, which means that
+ # if you set the directory wrong you can have other important files blown away.
+ # """
+
+ self.adaptive_mask_model = PointRendPredictor(
+ # pointrend_thres=0.2,
+ pointrend_thres=0.9,
+ device="cuda" if torch.cuda.is_available() else "cpu",
+ use_visualizer=use_visualizer,
+ config_pth="pointrend_rcnn_R_50_FPN_3x_coco.yaml",
+ weights_pth="model_final_edd263.pkl",
+ )
+
+ def adapt_mask(self, init_image, pred_orig_image, default_mask_image, dilate_num, use_default_mask, **kwargs):
+ ## predict mask to use for adaptation
+ adapt_output = self.adaptive_mask_model(pred_orig_image) # vis can be None if 'use_visualizer' is False
+ mask = adapt_output["mask"]
+ vis = adapt_output["vis"]
+
+ ## if mask is empty or too small, use default_mask_image. else, use dilate and intersect with default_mask_image
+ if use_default_mask or mask.sum() < 512 * 512 * kwargs["human_detection_thres"]: # 0.005
+ # set mask as default mask
+ mask = default_mask_image # HxW
+
+ else:
+ ## timestep-adaptive mask
+ mask = cv2.dilate(
+ mask, self.adaptive_mask_settings.dilate_kernel, iterations=dilate_num
+ ) # dilate_kernel: np.ones((3,3), np.uint8)
+ mask = np.logical_and(mask, default_mask_image) # HxW
+
+ ## prepare mask as pt tensor format
+ mask = torch.tensor(mask, dtype=torch.float32).to(kwargs["device"])[None, None] # 1 x 1 x H x W
+ mask, masked_image = prepare_mask_and_masked_image(
+ init_image.to(kwargs["device"]), mask, kwargs["height"], kwargs["width"], return_image=False
+ )
+
+ mask_image_np = mask.clone().squeeze().detach().cpu().numpy()
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ kwargs["batch_size"] * kwargs["num_images_per_prompt"],
+ kwargs["height"],
+ kwargs["width"],
+ kwargs["prompt_embeds"].dtype,
+ kwargs["device"],
+ kwargs["generator"],
+ kwargs["do_classifier_free_guidance"],
+ )
+
+ return mask, masked_image_latents, mask_image_np, vis
+
+
+def seg2bbox(seg_mask: np.ndarray):
+ nonzero_i, nonzero_j = seg_mask.nonzero()
+ min_i, max_i = nonzero_i.min(), nonzero_i.max()
+ min_j, max_j = nonzero_j.min(), nonzero_j.max()
+
+ return np.array([min_j, min_i, max_j + 1, max_i + 1])
+
+
+def merge_bbox(bboxes: list):
+ assert len(bboxes) > 0
+
+ all_bboxes = np.stack(bboxes, axis=0) # shape: N_bbox X 4
+ merged_bbox = np.zeros_like(all_bboxes[0]) # shape: 4,
+
+ merged_bbox[0] = all_bboxes[:, 0].min()
+ merged_bbox[1] = all_bboxes[:, 1].min()
+ merged_bbox[2] = all_bboxes[:, 2].max()
+ merged_bbox[3] = all_bboxes[:, 3].max()
+
+ return merged_bbox
+
+
+class PointRendPredictor:
+ def __init__(
+ self,
+ cat_id_to_focus=0,
+ pointrend_thres=0.9,
+ device="cuda",
+ use_visualizer=False,
+ merge_mode="merge",
+ config_pth=None,
+ weights_pth=None,
+ ):
+ super().__init__()
+
+ # category id to focus (default: 0, which is human)
+ self.cat_id_to_focus = cat_id_to_focus
+
+ # setup coco metadata
+ self.coco_metadata = MetadataCatalog.get("coco_2017_val")
+ self.cfg = get_cfg()
+
+ # get segmentation model config
+ point_rend.add_pointrend_config(self.cfg) # --> Add PointRend-specific config
+ self.cfg.merge_from_file(config_pth)
+ self.cfg.MODEL.WEIGHTS = weights_pth
+ self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = pointrend_thres
+ self.cfg.MODEL.DEVICE = device
+
+ # get segmentation model
+ self.pointrend_seg_model = DefaultPredictor(self.cfg)
+
+ # settings for visualizer
+ self.use_visualizer = use_visualizer
+
+ # mask-merge mode
+ assert merge_mode in ["merge", "max-confidence"], f"'merge_mode': {merge_mode} not implemented."
+ self.merge_mode = merge_mode
+
+ def merge_mask(self, masks, scores=None):
+ if self.merge_mode == "merge":
+ mask = np.any(masks, axis=0)
+ elif self.merge_mode == "max-confidence":
+ mask = masks[np.argmax(scores)]
+ return mask
+
+ def vis_seg_on_img(self, image, mask):
+ if type(mask) == np.ndarray:
+ mask = torch.tensor(mask)
+ v = Visualizer(image, self.coco_metadata, scale=0.5, instance_mode=ColorMode.IMAGE_BW)
+ instances = Instances(image_size=image.shape[:2], pred_masks=mask if len(mask.shape) == 3 else mask[None])
+ vis = v.draw_instance_predictions(instances.to("cpu")).get_image()
+ return vis
+
+ def __call__(self, image):
+ # run segmentation
+ outputs = self.pointrend_seg_model(image)
+ instances = outputs["instances"]
+
+ # merge instances for the category-id to focus
+ is_class = instances.pred_classes == self.cat_id_to_focus
+ masks = instances.pred_masks[is_class]
+ masks = masks.detach().cpu().numpy() # [N, img_size, img_size]
+ mask = self.merge_mask(masks, scores=instances.scores[is_class])
+
+ return {
+ "asset_mask": None,
+ "mask": mask.astype(np.uint8),
+ "vis": self.vis_seg_on_img(image, mask) if self.use_visualizer else None,
+ }
+
+
+class MaskDilateScheduler:
+ def __init__(self, max_dilate_num=15, num_inference_steps=50, schedule=None):
+ super().__init__()
+ self.max_dilate_num = max_dilate_num
+ self.schedule = [num_inference_steps - i for i in range(num_inference_steps)] if schedule is None else schedule
+ assert len(self.schedule) == num_inference_steps
+
+ def __call__(self, i):
+ return min(self.max_dilate_num, self.schedule[i])
+
+
+class ProvokeScheduler:
+ def __init__(self, num_inference_steps=50, schedule=None, is_zero_indexing=False):
+ super().__init__()
+ if len(schedule) > 0:
+ if is_zero_indexing:
+ assert max(schedule) <= num_inference_steps - 1
+ else:
+ assert max(schedule) <= num_inference_steps
+
+ # register as self
+ self.is_zero_indexing = is_zero_indexing
+ self.schedule = schedule
+
+ def __call__(self, i):
+ if self.is_zero_indexing:
+ return i in self.schedule
+ else:
+ return i + 1 in self.schedule
diff --git a/examples/community/checkpoint_merger.py b/examples/community/checkpoint_merger.py
index 6ba4b8c6e837..f23e8a207e36 100644
--- a/examples/community/checkpoint_merger.py
+++ b/examples/community/checkpoint_merger.py
@@ -92,9 +92,13 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
token = kwargs.pop("token", None)
variant = kwargs.pop("variant", None)
revision = kwargs.pop("revision", None)
- torch_dtype = kwargs.pop("torch_dtype", None)
+ torch_dtype = kwargs.pop("torch_dtype", torch.float32)
device_map = kwargs.pop("device_map", None)
+ if not isinstance(torch_dtype, torch.dtype):
+ torch_dtype = torch.float32
+ print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
+
alpha = kwargs.pop("alpha", 0.5)
interp = kwargs.pop("interp", None)
diff --git a/examples/community/cogvideox_ddim_inversion.py b/examples/community/cogvideox_ddim_inversion.py
new file mode 100644
index 000000000000..e9d1746d2d64
--- /dev/null
+++ b/examples/community/cogvideox_ddim_inversion.py
@@ -0,0 +1,645 @@
+"""
+This script performs DDIM inversion for video frames using a pre-trained model and generates
+a video reconstruction based on a provided prompt. It utilizes the CogVideoX pipeline to
+process video frames, apply the DDIM inverse scheduler, and produce an output video.
+
+**Please notice that this script is based on the CogVideoX 5B model, and would not generate
+a good result for 2B variants.**
+
+Usage:
+ python cogvideox_ddim_inversion.py
+ --model-path /path/to/model
+ --prompt "a prompt"
+ --video-path /path/to/video.mp4
+ --output-path /path/to/output
+
+For more details about the cli arguments, please run `python cogvideox_ddim_inversion.py --help`.
+
+Author:
+ LittleNyima
+"""
+
+import argparse
+import math
+import os
+from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast
+
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0
+from diffusers.models.autoencoders import AutoencoderKLCogVideoX
+from diffusers.models.embeddings import apply_rotary_emb
+from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel
+from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, retrieve_timesteps
+from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler
+from diffusers.utils import export_to_video
+
+
+# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error.
+# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
+import decord # isort: skip
+
+
+class DDIMInversionArguments(TypedDict):
+ model_path: str
+ prompt: str
+ video_path: str
+ output_path: str
+ guidance_scale: float
+ num_inference_steps: int
+ skip_frames_start: int
+ skip_frames_end: int
+ frame_sample_step: Optional[int]
+ max_num_frames: int
+ width: int
+ height: int
+ fps: int
+ dtype: torch.dtype
+ seed: int
+ device: torch.device
+
+
+def get_args() -> DDIMInversionArguments:
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", type=str, required=True, help="Path of the pretrained model")
+ parser.add_argument("--prompt", type=str, required=True, help="Prompt for the direct sample procedure")
+ parser.add_argument("--video_path", type=str, required=True, help="Path of the video for inversion")
+ parser.add_argument("--output_path", type=str, default="output", help="Path of the output videos")
+ parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
+ parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
+ parser.add_argument("--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start")
+ parser.add_argument("--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end")
+ parser.add_argument("--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames")
+ parser.add_argument("--max_num_frames", type=int, default=81, help="Max number of sampled frames")
+ parser.add_argument("--width", type=int, default=720, help="Resized width of the video frames")
+ parser.add_argument("--height", type=int, default=480, help="Resized height of the video frames")
+ parser.add_argument("--fps", type=int, default=8, help="Frame rate of the output videos")
+ parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model")
+ parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator")
+ parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference")
+
+ args = parser.parse_args()
+ args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
+ args.device = torch.device(args.device)
+
+ return DDIMInversionArguments(**vars(args))
+
+
+class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
+ def __init__(self):
+ super().__init__()
+
+ def calculate_attention(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn: Attention,
+ batch_size: int,
+ image_seq_length: int,
+ text_seq_length: int,
+ attention_mask: Optional[torch.Tensor],
+ image_rotary_emb: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Core attention computation with inversion-guided RoPE integration.
+
+ Args:
+ query (`torch.Tensor`): `[batch_size, seq_len, dim]` query tensor
+ key (`torch.Tensor`): `[batch_size, seq_len, dim]` key tensor
+ value (`torch.Tensor`): `[batch_size, seq_len, dim]` value tensor
+ attn (`Attention`): Parent attention module with projection layers
+ batch_size (`int`): Effective batch size (after chunk splitting)
+ image_seq_length (`int`): Length of image feature sequence
+ text_seq_length (`int`): Length of text feature sequence
+ attention_mask (`Optional[torch.Tensor]`): Attention mask tensor
+ image_rotary_emb (`Optional[torch.Tensor]`): Rotary embeddings for image positions
+
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`:
+ (1) hidden_states: [batch_size, image_seq_length, dim] processed image features
+ (2) encoder_hidden_states: [batch_size, text_seq_length, dim] processed text features
+ """
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+ if not attn.is_cross_attention:
+ if key.size(2) == query.size(2): # Attention for reference hidden states
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+ else: # RoPE should be applied to each group of image tokens
+ key[:, :, text_seq_length : text_seq_length + image_seq_length] = apply_rotary_emb(
+ key[:, :, text_seq_length : text_seq_length + image_seq_length], image_rotary_emb
+ )
+ key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb(
+ key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb
+ )
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
+ return hidden_states, encoder_hidden_states
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Process the dual-path attention for the inversion-guided denoising procedure.
+
+ Args:
+ attn (`Attention`): Parent attention module
+ hidden_states (`torch.Tensor`): `[batch_size, image_seq_len, dim]` Image tokens
+ encoder_hidden_states (`torch.Tensor`): `[batch_size, text_seq_len, dim]` Text tokens
+ attention_mask (`Optional[torch.Tensor]`): Optional attention mask
+ image_rotary_emb (`Optional[torch.Tensor]`): Rotary embeddings for image tokens
+
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`:
+ (1) Final hidden states: `[batch_size, image_seq_length, dim]` Resulting image tokens
+ (2) Final encoder states: `[batch_size, text_seq_length, dim]` Resulting text tokens
+ """
+ image_seq_length = hidden_states.size(1)
+ text_seq_length = encoder_hidden_states.size(1)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query, query_reference = query.chunk(2)
+ key, key_reference = key.chunk(2)
+ value, value_reference = value.chunk(2)
+ batch_size = batch_size // 2
+
+ hidden_states, encoder_hidden_states = self.calculate_attention(
+ query=query,
+ key=torch.cat((key, key_reference), dim=1),
+ value=torch.cat((value, value_reference), dim=1),
+ attn=attn,
+ batch_size=batch_size,
+ image_seq_length=image_seq_length,
+ text_seq_length=text_seq_length,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states_reference, encoder_hidden_states_reference = self.calculate_attention(
+ query=query_reference,
+ key=key_reference,
+ value=value_reference,
+ attn=attn,
+ batch_size=batch_size,
+ image_seq_length=image_seq_length,
+ text_seq_length=text_seq_length,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ return (
+ torch.cat((hidden_states, hidden_states_reference)),
+ torch.cat((encoder_hidden_states, encoder_hidden_states_reference)),
+ )
+
+
+class OverrideAttnProcessors:
+ r"""
+ Context manager for temporarily overriding attention processors in CogVideo transformer blocks.
+
+ Designed for DDIM inversion process, replaces original attention processors with
+ `CogVideoXAttnProcessor2_0ForDDIMInversion` and restores them upon exit. Uses Python context manager
+ pattern to safely manage processor replacement.
+
+ Typical usage:
+ ```python
+ with OverrideAttnProcessors(transformer):
+ # Perform DDIM inversion operations
+ ```
+
+ Args:
+ transformer (`CogVideoXTransformer3DModel`):
+ The transformer model containing attention blocks to be modified. Should have
+ `transformer_blocks` attribute containing `CogVideoXBlock` instances.
+ """
+
+ def __init__(self, transformer: CogVideoXTransformer3DModel):
+ self.transformer = transformer
+ self.original_processors = {}
+
+ def __enter__(self):
+ for block in self.transformer.transformer_blocks:
+ block = cast(CogVideoXBlock, block)
+ self.original_processors[id(block)] = block.attn1.get_processor()
+ block.attn1.set_processor(CogVideoXAttnProcessor2_0ForDDIMInversion())
+
+ def __exit__(self, _0, _1, _2):
+ for block in self.transformer.transformer_blocks:
+ block = cast(CogVideoXBlock, block)
+ block.attn1.set_processor(self.original_processors[id(block)])
+
+
+def get_video_frames(
+ video_path: str,
+ width: int,
+ height: int,
+ skip_frames_start: int,
+ skip_frames_end: int,
+ max_num_frames: int,
+ frame_sample_step: Optional[int],
+) -> torch.FloatTensor:
+ """
+ Extract and preprocess video frames from a video file for VAE processing.
+
+ Args:
+ video_path (`str`): Path to input video file
+ width (`int`): Target frame width for decoding
+ height (`int`): Target frame height for decoding
+ skip_frames_start (`int`): Number of frames to skip at video start
+ skip_frames_end (`int`): Number of frames to skip at video end
+ max_num_frames (`int`): Maximum allowed number of output frames
+ frame_sample_step (`Optional[int]`):
+ Frame sampling step size. If None, automatically calculated as:
+ (total_frames - skipped_frames) // max_num_frames
+
+ Returns:
+ `torch.FloatTensor`: Preprocessed frames in `[F, C, H, W]` format where:
+ - `F`: Number of frames (adjusted to 4k + 1 for VAE compatibility)
+ - `C`: Channels (3 for RGB)
+ - `H`: Frame height
+ - `W`: Frame width
+ """
+ with decord.bridge.use_torch():
+ video_reader = decord.VideoReader(uri=video_path, width=width, height=height)
+ video_num_frames = len(video_reader)
+ start_frame = min(skip_frames_start, video_num_frames)
+ end_frame = max(0, video_num_frames - skip_frames_end)
+
+ if end_frame <= start_frame:
+ indices = [start_frame]
+ elif end_frame - start_frame <= max_num_frames:
+ indices = list(range(start_frame, end_frame))
+ else:
+ step = frame_sample_step or (end_frame - start_frame) // max_num_frames
+ indices = list(range(start_frame, end_frame, step))
+
+ frames = video_reader.get_batch(indices=indices)
+ frames = frames[:max_num_frames].float() # ensure that we don't go over the limit
+
+ # Choose first (4k + 1) frames as this is how many is required by the VAE
+ selected_num_frames = frames.size(0)
+ remainder = (3 + selected_num_frames) % 4
+ if remainder != 0:
+ frames = frames[:-remainder]
+ assert frames.size(0) % 4 == 1
+
+ # Normalize the frames
+ transform = T.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
+ frames = torch.stack(tuple(map(transform, frames)), dim=0)
+
+ return frames.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
+
+
+class CogVideoXDDIMInversionOutput:
+ inverse_latents: torch.FloatTensor
+ recon_latents: torch.FloatTensor
+
+ def __init__(self, inverse_latents: torch.FloatTensor, recon_latents: torch.FloatTensor):
+ self.inverse_latents = inverse_latents
+ self.recon_latents = recon_latents
+
+
+class CogVideoXPipelineForDDIMInversion(CogVideoXPipeline):
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: CogVideoXTransformer3DModel,
+ scheduler: CogVideoXDDIMScheduler,
+ ):
+ super().__init__(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.inverse_scheduler = DDIMInverseScheduler(**scheduler.config)
+
+ def encode_video_frames(self, video_frames: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ Encode video frames into latent space using Variational Autoencoder.
+
+ Args:
+ video_frames (`torch.FloatTensor`):
+ Input frames tensor in `[F, C, H, W]` format from `get_video_frames()`
+
+ Returns:
+ `torch.FloatTensor`: Encoded latents in `[1, F, D, H_latent, W_latent]` format where:
+ - `F`: Number of frames (same as input)
+ - `D`: Latent channel dimension
+ - `H_latent`: Latent space height (H // 2^vae.downscale_factor)
+ - `W_latent`: Latent space width (W // 2^vae.downscale_factor)
+ """
+ vae: AutoencoderKLCogVideoX = self.vae
+ video_frames = video_frames.to(device=vae.device, dtype=vae.dtype)
+ video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
+ latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2)
+ return latent_dist * vae.config.scaling_factor
+
+ @torch.no_grad()
+ def export_latents_to_video(self, latents: torch.FloatTensor, video_path: str, fps: int):
+ r"""
+ Decode latent vectors into video and export as video file.
+
+ Args:
+ latents (`torch.FloatTensor`): Encoded latents in `[B, F, D, H_latent, W_latent]` format from
+ `encode_video_frames()`
+ video_path (`str`): Output path for video file
+ fps (`int`): Target frames per second for output video
+ """
+ video = self.decode_latents(latents)
+ frames = self.video_processor.postprocess_video(video=video, output_type="pil")
+ os.makedirs(os.path.dirname(video_path), exist_ok=True)
+ export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps)
+
+ # Modified from CogVideoXPipeline.__call__
+ @torch.no_grad()
+ def sample(
+ self,
+ latents: torch.FloatTensor,
+ scheduler: Union[DDIMInverseScheduler, CogVideoXDDIMScheduler],
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ reference_latents: torch.FloatTensor = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Execute the core sampling loop for video generation/inversion using CogVideoX.
+
+ Implements the full denoising trajectory recording for both DDIM inversion and
+ generation processes. Supports dynamic classifier-free guidance and reference
+ latent conditioning.
+
+ Args:
+ latents (`torch.FloatTensor`):
+ Initial noise tensor of shape `[B, F, C, H, W]`.
+ scheduler (`Union[DDIMInverseScheduler, CogVideoXDDIMScheduler]`):
+ Scheduling strategy for diffusion process. Use:
+ (1) `DDIMInverseScheduler` for inversion
+ (2) `CogVideoXDDIMScheduler` for generation
+ prompt (`Optional[Union[str, List[str]]]`):
+ Text prompt(s) for conditional generation. Defaults to unconditional.
+ negative_prompt (`Optional[Union[str, List[str]]]`):
+ Negative prompt(s) for guidance. Requires `guidance_scale > 1`.
+ num_inference_steps (`int`):
+ Number of denoising steps. Affects quality/compute trade-off.
+ guidance_scale (`float`):
+ Classifier-free guidance weight. 1.0 = no guidance.
+ use_dynamic_cfg (`bool`):
+ Enable time-varying guidance scale (cosine schedule)
+ eta (`float`):
+ DDIM variance parameter (0 = deterministic process)
+ generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`):
+ Random number generator(s) for reproducibility
+ attention_kwargs (`Optional[Dict[str, Any]]`):
+ Custom parameters for attention modules
+ reference_latents (`torch.FloatTensor`):
+ Reference latent trajectory for conditional sampling. Shape should match
+ `[T, B, F, C, H, W]` where `T` is number of timesteps
+
+ Returns:
+ `torch.FloatTensor`:
+ Full denoising trajectory tensor of shape `[T, B, F, C, H, W]`.
+ """
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ device=device,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ if reference_latents is not None:
+ prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents.
+ latents = latents.to(device=device) * scheduler.init_noise_sigma
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ if isinstance(scheduler, DDIMInverseScheduler): # Inverse scheduler does not accept extra kwargs
+ extra_step_kwargs = {}
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(
+ height=latents.size(3) * self.vae_scale_factor_spatial,
+ width=latents.size(4) * self.vae_scale_factor_spatial,
+ num_frames=latents.size(1),
+ device=device,
+ )
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
+
+ trajectory = torch.zeros_like(latents).unsqueeze(0).repeat(len(timesteps), 1, 1, 1, 1, 1)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ if reference_latents is not None:
+ reference = reference_latents[i]
+ reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
+ latent_model_input = torch.cat([latent_model_input, reference], dim=0)
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ if reference_latents is not None: # Recover the original batch size
+ noise_pred, _ = noise_pred.chunk(2)
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the noisy sample x_t-1 -> x_t
+ latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ latents = latents.to(prompt_embeds.dtype)
+ trajectory[i] = latents
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
+ progress_bar.update()
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ return trajectory
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: str,
+ video_path: str,
+ guidance_scale: float,
+ num_inference_steps: int,
+ skip_frames_start: int,
+ skip_frames_end: int,
+ frame_sample_step: Optional[int],
+ max_num_frames: int,
+ width: int,
+ height: int,
+ seed: int,
+ ):
+ """
+ Performs DDIM inversion on a video to reconstruct it with a new prompt.
+
+ Args:
+ prompt (`str`): The text prompt to guide the reconstruction.
+ video_path (`str`): Path to the input video file.
+ guidance_scale (`float`): Scale for classifier-free guidance.
+ num_inference_steps (`int`): Number of denoising steps.
+ skip_frames_start (`int`): Number of frames to skip from the beginning of the video.
+ skip_frames_end (`int`): Number of frames to skip from the end of the video.
+ frame_sample_step (`Optional[int]`): Step size for sampling frames. If None, all frames are used.
+ max_num_frames (`int`): Maximum number of frames to process.
+ width (`int`): Width of the output video frames.
+ height (`int`): Height of the output video frames.
+ seed (`int`): Random seed for reproducibility.
+
+ Returns:
+ `CogVideoXDDIMInversionOutput`: Contains the inverse latents and reconstructed latents.
+ """
+ if not self.transformer.config.use_rotary_positional_embeddings:
+ raise NotImplementedError("This script supports CogVideoX 5B model only.")
+ video_frames = get_video_frames(
+ video_path=video_path,
+ width=width,
+ height=height,
+ skip_frames_start=skip_frames_start,
+ skip_frames_end=skip_frames_end,
+ max_num_frames=max_num_frames,
+ frame_sample_step=frame_sample_step,
+ ).to(device=self.device)
+ video_latents = self.encode_video_frames(video_frames=video_frames)
+ inverse_latents = self.sample(
+ latents=video_latents,
+ scheduler=self.inverse_scheduler,
+ prompt="",
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator(device=self.device).manual_seed(seed),
+ )
+ with OverrideAttnProcessors(transformer=self.transformer):
+ recon_latents = self.sample(
+ latents=torch.randn_like(video_latents),
+ scheduler=self.scheduler,
+ prompt=prompt,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator(device=self.device).manual_seed(seed),
+ reference_latents=reversed(inverse_latents),
+ )
+ return CogVideoXDDIMInversionOutput(
+ inverse_latents=inverse_latents,
+ recon_latents=recon_latents,
+ )
+
+
+if __name__ == "__main__":
+ arguments = get_args()
+ pipeline = CogVideoXPipelineForDDIMInversion.from_pretrained(
+ arguments.pop("model_path"),
+ torch_dtype=arguments.pop("dtype"),
+ ).to(device=arguments.pop("device"))
+
+ output_path = arguments.pop("output_path")
+ fps = arguments.pop("fps")
+ inverse_video_path = os.path.join(output_path, f"{arguments.get('video_path')}_inversion.mp4")
+ recon_video_path = os.path.join(output_path, f"{arguments.get('video_path')}_reconstruction.mp4")
+
+ # Run DDIM inversion
+ output = pipeline(**arguments)
+ pipeline.export_latents_to_video(output.inverse_latents[-1], inverse_video_path, fps)
+ pipeline.export_latents_to_video(output.recon_latents[-1], recon_video_path, fps)
diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py
index 46d12ba1f2aa..024818daf186 100644
--- a/examples/community/composable_stable_diffusion.py
+++ b/examples/community/composable_stable_diffusion.py
@@ -89,7 +89,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -103,7 +103,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -132,10 +132,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -162,7 +166,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py
index ac977f79abec..a7bc892ddf93 100644
--- a/examples/community/edict_pipeline.py
+++ b/examples/community/edict_pipeline.py
@@ -35,7 +35,7 @@ def __init__(
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
diff --git a/examples/community/fresco_v2v.py b/examples/community/fresco_v2v.py
index ab191ecf0d81..d6c2683f1d86 100644
--- a/examples/community/fresco_v2v.py
+++ b/examples/community/fresco_v2v.py
@@ -404,10 +404,11 @@ def forward(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
@@ -1342,7 +1343,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
diff --git a/examples/community/gluegen.py b/examples/community/gluegen.py
index 91026c5d966f..54cc562d5583 100644
--- a/examples/community/gluegen.py
+++ b/examples/community/gluegen.py
@@ -221,7 +221,7 @@ def __init__(
language_adapter=language_adapter,
tensor_norm=tensor_norm,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py
index 4dfb7a39155f..292c9aa2bc47 100644
--- a/examples/community/img2img_inpainting.py
+++ b/examples/community/img2img_inpainting.py
@@ -95,7 +95,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
diff --git a/examples/community/instaflow_one_step.py b/examples/community/instaflow_one_step.py
index 3fef02287186..e726b42756ee 100644
--- a/examples/community/instaflow_one_step.py
+++ b/examples/community/instaflow_one_step.py
@@ -109,7 +109,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -123,7 +123,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -152,10 +152,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -182,7 +186,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py
index 52b2707f33f7..99614635ee13 100644
--- a/examples/community/interpolate_stable_diffusion.py
+++ b/examples/community/interpolate_stable_diffusion.py
@@ -86,7 +86,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py
index c7dc775eeee3..648bf2933145 100644
--- a/examples/community/ip_adapter_face_id.py
+++ b/examples/community/ip_adapter_face_id.py
@@ -191,7 +191,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -205,7 +205,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -234,10 +234,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -265,7 +269,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/kohya_hires_fix.py b/examples/community/kohya_hires_fix.py
index 0e36f32b19a3..ddbb28896e13 100644
--- a/examples/community/kohya_hires_fix.py
+++ b/examples/community/kohya_hires_fix.py
@@ -463,6 +463,6 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/latent_consistency_img2img.py b/examples/community/latent_consistency_img2img.py
index 5fe53ab6b830..6c532c7f76c1 100644
--- a/examples/community/latent_consistency_img2img.py
+++ b/examples/community/latent_consistency_img2img.py
@@ -69,7 +69,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
diff --git a/examples/community/latent_consistency_interpolate.py b/examples/community/latent_consistency_interpolate.py
index 84adc125b191..34cdb0fec73b 100644
--- a/examples/community/latent_consistency_interpolate.py
+++ b/examples/community/latent_consistency_interpolate.py
@@ -273,7 +273,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/latent_consistency_txt2img.py b/examples/community/latent_consistency_txt2img.py
index 9f25a6db2722..7b60f5bb875c 100755
--- a/examples/community/latent_consistency_txt2img.py
+++ b/examples/community/latent_consistency_txt2img.py
@@ -67,7 +67,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py
index 49c074911354..129793dae6b0 100644
--- a/examples/community/llm_grounded_diffusion.py
+++ b/examples/community/llm_grounded_diffusion.py
@@ -336,7 +336,7 @@ def __init__(
# This is copied from StableDiffusionPipeline, with hook initizations for LMD+.
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -350,7 +350,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -379,10 +379,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -410,7 +414,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index ec27acdce331..32baf500d456 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -496,7 +496,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -510,7 +510,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -539,10 +539,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -568,7 +572,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py
index 13d1e2a1156a..4d9683b73fc4 100644
--- a/examples/community/lpw_stable_diffusion_xl.py
+++ b/examples/community/lpw_stable_diffusion_xl.py
@@ -673,12 +673,16 @@ def __init__(
image_encoder=image_encoder,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -827,7 +831,9 @@ def encode_prompt(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
@@ -879,7 +885,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1766,7 +1773,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
@@ -1917,7 +1924,22 @@ def denoising_value_valid(dnv):
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed
if needs_upcasting:
diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py
index 92f01d046ef9..cdee18e0eee9 100644
--- a/examples/community/marigold_depth_estimation.py
+++ b/examples/community/marigold_depth_estimation.py
@@ -43,7 +43,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
class MarigoldDepthOutput(BaseOutput):
diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py
index 7ac0ab542910..4895bd150114 100644
--- a/examples/community/matryoshka.py
+++ b/examples/community/matryoshka.py
@@ -80,7 +80,6 @@
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
- is_torch_version,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -868,24 +867,8 @@ def forward(
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1029,18 +1012,7 @@ def forward(
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1049,12 +1021,7 @@ def custom_forward(*inputs):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = attn(
hidden_states,
@@ -1191,24 +1158,8 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1282,10 +1233,6 @@ def __init__(
]
)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -1364,20 +1311,9 @@ def forward(
# Blocks
for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -1385,7 +1321,6 @@ def custom_forward(*inputs):
timestep,
cross_attention_kwargs,
class_labels,
- **ckpt_kwargs,
)
else:
hidden_states = block(
@@ -2724,10 +2659,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -2806,10 +2737,11 @@ def get_time_embed(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
@@ -3766,7 +3698,7 @@ def __init__(
else:
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -3780,7 +3712,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ # if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
# deprecation_message = (
# f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
# " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -3793,10 +3725,14 @@ def __init__(
# new_config["clip_sample"] = False
# scheduler._internal_dict = FrozenDict(new_config)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
diff --git a/examples/community/mixture_tiling_sdxl.py b/examples/community/mixture_tiling_sdxl.py
new file mode 100644
index 000000000000..bd56ddb3d61d
--- /dev/null
+++ b/examples/community/mixture_tiling_sdxl.py
@@ -0,0 +1,1237 @@
+# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import (
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+)
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+try:
+ from ligo.segments import segment
+except ImportError:
+ raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLPipeline
+
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
+ """Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image
+
+ Returns a tuple with:
+ - Starting coordinates of rows in pixel space
+ - Ending coordinates of rows in pixel space
+ - Starting coordinates of columns in pixel space
+ - Ending coordinates of columns in pixel space
+ """
+ px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap)
+ px_row_end = px_row_init + tile_height
+ px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap)
+ px_col_end = px_col_init + tile_width
+ return px_row_init, px_row_end, px_col_init, px_col_end
+
+
+def _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end):
+ """Translates coordinates in pixel space to coordinates in latent space"""
+ return px_row_init // 8, px_row_end // 8, px_col_init // 8, px_col_end // 8
+
+
+def _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
+ """Given a tile row and column numbers returns the range of latents affected by that tiles in the overall image
+
+ Returns a tuple with:
+ - Starting coordinates of rows in latent space
+ - Ending coordinates of rows in latent space
+ - Starting coordinates of columns in latent space
+ - Ending coordinates of columns in latent space
+ """
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ return _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end)
+
+
+def _tile2latent_exclusive_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns
+):
+ """Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image
+
+ Returns a tuple with:
+ - Starting coordinates of rows in latent space
+ - Ending coordinates of rows in latent space
+ - Starting coordinates of columns in latent space
+ - Ending coordinates of columns in latent space
+ """
+ row_init, row_end, col_init, col_end = _tile2latent_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ row_segment = segment(row_init, row_end)
+ col_segment = segment(col_init, col_end)
+ # Iterate over the rest of tiles, clipping the region for the current tile
+ for row in range(rows):
+ for column in range(columns):
+ if row != tile_row and column != tile_col:
+ clip_row_init, clip_row_end, clip_col_init, clip_col_end = _tile2latent_indices(
+ row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ row_segment = row_segment - segment(clip_row_init, clip_row_end)
+ col_segment = col_segment - segment(clip_col_init, clip_col_end)
+ # return row_init, row_end, col_init, col_end
+ return row_segment[0], row_segment[1], col_segment[0], col_segment[1]
+
+
+def _get_crops_coords_list(num_rows, num_cols, output_width):
+ """
+ Generates a list of lists of `crops_coords_top_left` tuples for focusing on
+ different horizontal parts of an image, and repeats this list for the specified
+ number of rows in the output structure.
+
+ This function calculates `crops_coords_top_left` tuples to create horizontal
+ focus variations (like left, center, right focus) based on `output_width`
+ and `num_cols` (which represents the number of horizontal focus points/columns).
+ It then repeats the *list* of these horizontal focus tuples `num_rows` times to
+ create the final list of lists output structure.
+
+ Args:
+ num_rows (int): The desired number of rows in the output list of lists.
+ This determines how many times the list of horizontal
+ focus variations will be repeated.
+ num_cols (int): The number of horizontal focus points (columns) to generate.
+ This determines how many horizontal focus variations are
+ created based on dividing the `output_width`.
+ output_width (int): The desired width of the output image.
+
+ Returns:
+ list[list[tuple[int, int]]]: A list of lists of tuples. Each inner list
+ contains `num_cols` tuples of `(ctop, cleft)`,
+ representing horizontal focus points. The outer list
+ contains `num_rows` such inner lists.
+ """
+ crops_coords_list = []
+ if num_cols <= 0:
+ crops_coords_list = []
+ elif num_cols == 1:
+ crops_coords_list = [(0, 0)]
+ else:
+ section_width = output_width / num_cols
+ for i in range(num_cols):
+ cleft = int(round(i * section_width))
+ crops_coords_list.append((0, cleft))
+
+ result_list = []
+ for _ in range(num_rows):
+ result_list.append(list(crops_coords_list))
+
+ return result_list
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionXLTilingPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ FromSingleFileMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ class SeedTilesMode(Enum):
+ """Modes in which the latents of a particular tile can be re-seeded"""
+
+ FULL = "full"
+ EXCLUSIVE = "exclusive"
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, height, width, grid_cols, seed_tiles_mode, tiles_mode):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):
+ raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}")
+
+ if not all(len(row) == grid_cols for row in prompt):
+ raise ValueError("All prompt rows must have the same number of prompt columns")
+
+ if not isinstance(seed_tiles_mode, str) and (
+ not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)
+ ):
+ raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
+
+ if any(mode not in tiles_mode for row in seed_tiles_mode for mode in row):
+ raise ValueError(f"Seed tiles mode must be one of {tiles_mode}")
+
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ def _gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype):
+ """Generates a gaussian mask of weights for tile contributions"""
+ import numpy as np
+ from numpy import exp, pi, sqrt
+
+ latent_width = tile_width // 8
+ latent_height = tile_height // 8
+
+ var = 0.01
+ midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
+ x_probs = [
+ exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)
+ for x in range(latent_width)
+ ]
+ midpoint = latent_height / 2
+ y_probs = [
+ exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)
+ for y in range(latent_height)
+ ]
+
+ weights_np = np.outer(y_probs, x_probs)
+ weights_torch = torch.tensor(weights_np, device=device)
+ weights_torch = weights_torch.to(dtype)
+ return torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))
+
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ FusedAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Optional[List[List[Tuple[int, int]]]] = None,
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Optional[List[List[Tuple[int, int]]]] = None,
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ tile_height: Optional[int] = 1024,
+ tile_width: Optional[int] = 1024,
+ tile_row_overlap: Optional[int] = 128,
+ tile_col_overlap: Optional[int] = 128,
+ guidance_scale_tiles: Optional[List[List[float]]] = None,
+ seed_tiles: Optional[List[List[int]]] = None,
+ seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full",
+ seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`List[List[Tuple[int, int]]]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`List[List[Tuple[int, int]]]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be as same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ tile_height (`int`, *optional*, defaults to 1024):
+ Height of each grid tile in pixels.
+ tile_width (`int`, *optional*, defaults to 1024):
+ Width of each grid tile in pixels.
+ tile_row_overlap (`int`, *optional*, defaults to 128):
+ Number of overlapping pixels between tiles in consecutive rows.
+ tile_col_overlap (`int`, *optional*, defaults to 128):
+ Number of overlapping pixels between tiles in consecutive columns.
+ guidance_scale_tiles (`List[List[float]]`, *optional*):
+ Specific weights for classifier-free guidance in each tile. If `None`, the value provided in `guidance_scale` will be used.
+ seed_tiles (`List[List[int]]`, *optional*):
+ Specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard `generator` parameter.
+ seed_tiles_mode (`Union[str, List[List[str]]]`, *optional*, defaults to `"full"`):
+ Mode for seeding tiles, can be `"full"` or `"exclusive"`. If `"full"`, all the latents affected by the tile will be overridden. If `"exclusive"`, only the latents that are exclusively affected by this tile (and no other tiles) will be overridden.
+ seed_reroll_regions (`List[Tuple[int, int, int, int, int]]`, *optional*):
+ A list of tuples in the form of `(start_row, end_row, start_column, end_column, seed)` defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over `seed_tiles`.
+ **kwargs (`Dict[str, Any]`, *optional*):
+ Additional optional keyword arguments to be passed to the `unet.__call__` and `scheduler.step` functions.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLTilingPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLTilingPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+ negative_original_size = negative_original_size or (height, width)
+ negative_target_size = negative_target_size or (height, width)
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ grid_rows = len(prompt)
+ grid_cols = len(prompt[0])
+
+ tiles_mode = [mode.value for mode in self.SeedTilesMode]
+
+ if isinstance(seed_tiles_mode, str):
+ seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ grid_cols,
+ seed_tiles_mode,
+ tiles_mode,
+ )
+
+ if seed_reroll_regions is None:
+ seed_reroll_regions = []
+
+ batch_size = 1
+
+ device = self._execution_device
+
+ # update crops coords list
+ crops_coords_top_left = _get_crops_coords_list(grid_rows, grid_cols, tile_width)
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_crops_coords_top_left = _get_crops_coords_list(grid_rows, grid_cols, tile_width)
+
+ # update height and width tile size and tile overlap size
+ height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)
+ width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ text_embeddings = [
+ [
+ self.encode_prompt(
+ prompt=col,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ for col in row
+ ]
+ for row in prompt
+ ]
+
+ # 3. Prepare latents
+ latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
+ dtype = text_embeddings[0][0][0].dtype
+ latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
+
+ # 3.1 overwrite latents for specific tiles if provided
+ if seed_tiles is not None:
+ for row in range(grid_rows):
+ for col in range(grid_cols):
+ if (seed_tile := seed_tiles[row][col]) is not None:
+ mode = seed_tiles_mode[row][col]
+ if mode == self.SeedTilesMode.FULL.value:
+ row_init, row_end, col_init, col_end = _tile2latent_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ else:
+ row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices(
+ row,
+ col,
+ tile_width,
+ tile_height,
+ tile_row_overlap,
+ tile_col_overlap,
+ grid_rows,
+ grid_cols,
+ )
+ tile_generator = torch.Generator(device).manual_seed(seed_tile)
+ tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
+ latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(
+ tile_shape, generator=tile_generator, device=device
+ )
+
+ # 3.2 overwrite again for seed reroll regions
+ for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions:
+ row_init, row_end, col_init, col_end = _pixel2latent_indices(
+ row_init, row_end, col_init, col_end
+ ) # to latent space coordinates
+ reroll_generator = torch.Generator(device).manual_seed(seed_reroll)
+ region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
+ latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(
+ region_shape, generator=reroll_generator, device=device
+ )
+
+ # 4. Prepare timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, None, None, **extra_set_kwargs
+ )
+
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6. Prepare added time ids & embeddings
+ # text_embeddings order: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+ embeddings_and_added_time = []
+ for row in range(grid_rows):
+ addition_embed_type_row = []
+ for col in range(grid_cols):
+ # extract generated values
+ prompt_embeds = text_embeddings[row][col][0]
+ negative_prompt_embeds = text_embeddings[row][col][1]
+ pooled_prompt_embeds = text_embeddings[row][col][2]
+ negative_pooled_prompt_embeds = text_embeddings[row][col][3]
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left[row][col],
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left[row][col],
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+ addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
+ embeddings_and_added_time.append(addition_embed_type_row)
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 7. Mask for tile weights strength
+ tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size, device, torch.float32)
+
+ # 8. Denoising loop
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Diffuse each tile
+ noise_preds = []
+ for row in range(grid_rows):
+ noise_preds_row = []
+ for col in range(grid_cols):
+ if self.interrupt:
+ continue
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([tile_latents] * 2) if self.do_classifier_free_guidance else tile_latents
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {
+ "text_embeds": embeddings_and_added_time[row][col][1],
+ "time_ids": embeddings_and_added_time[row][col][2],
+ }
+ with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=embeddings_and_added_time[row][col][0],
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ guidance = (
+ guidance_scale
+ if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None
+ else guidance_scale_tiles[row][col]
+ )
+ noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
+ noise_preds_row.append(noise_pred_tile)
+ noise_preds.append(noise_preds_row)
+
+ # Stitch noise predictions for all tiles
+ noise_pred = torch.zeros(latents.shape, device=device)
+ contributors = torch.zeros(latents.shape, device=device)
+
+ # Add each tile contribution to overall latents
+ for row in range(grid_rows):
+ for col in range(grid_cols):
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += (
+ noise_preds[row][col] * tile_weights
+ )
+ contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights
+
+ # Average overlapping areas with more than 1 contributor
+ noise_pred /= contributors
+ noise_pred = noise_pred.to(dtype)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ # update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ elif latents.dtype != self.vae.dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ self.vae = self.vae.to(latents.dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/examples/community/mod_controlnet_tile_sr_sdxl.py b/examples/community/mod_controlnet_tile_sr_sdxl.py
new file mode 100644
index 000000000000..80bed2365d9f
--- /dev/null
+++ b/examples/community/mod_controlnet_tile_sr_sdxl.py
@@ -0,0 +1,1862 @@
+# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from transformers import (
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import (
+ AutoencoderKL,
+ ControlNetModel,
+ ControlNetUnionModel,
+ MultiControlNetModel,
+ UNet2DConditionModel,
+)
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.import_utils import is_invisible_watermark_available
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+from diffusers.utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ import torch
+ from diffusers import DiffusionPipeline, ControlNetUnionModel, AutoencoderKL, UniPCMultistepScheduler
+ from diffusers.utils import load_image
+ from PIL import Image
+
+ device = "cuda"
+
+ # Initialize the models and pipeline
+ controlnet = ControlNetUnionModel.from_pretrained(
+ "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
+ ).to(device=device)
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device=device)
+
+ model_id = "SG161222/RealVisXL_V5.0"
+ pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
+ model_id, controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
+ ).to(device)
+
+ pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
+ pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
+ pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
+
+ # Set selected scheduler
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+ # Load image
+ control_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/1.jpg")
+ original_height = control_image.height
+ original_width = control_image.width
+ print(f"Current resolution: H:{original_height} x W:{original_width}")
+
+ # Pre-upscale image for tiling
+ resolution = 4096
+ tile_gaussian_sigma = 0.3
+ max_tile_size = 1024 # or 1280
+
+ current_size = max(control_image.size)
+ scale_factor = max(2, resolution / current_size)
+ new_size = (int(control_image.width * scale_factor), int(control_image.height * scale_factor))
+ image = control_image.resize(new_size, Image.LANCZOS)
+
+ # Update target height and width
+ target_height = image.height
+ target_width = image.width
+ print(f"Target resolution: H:{target_height} x W:{target_width}")
+
+ # Calculate overlap size
+ normal_tile_overlap, border_tile_overlap = calculate_overlap(target_width, target_height)
+
+ # Set other params
+ tile_weighting_method = TileWeightingMethod.COSINE.value
+ guidance_scale = 4
+ num_inference_steps = 35
+ denoising_strenght = 0.65
+ controlnet_strength = 1.0
+ prompt = "high-quality, noise-free edges, high quality, 4k, hd, 8k"
+ negative_prompt = "blurry, pixelated, noisy, low resolution, artifacts, poor details"
+
+ # Image generation
+ control_image = pipe(
+ image=image,
+ control_image=control_image,
+ control_mode=[6],
+ controlnet_conditioning_scale=float(controlnet_strength),
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ normal_tile_overlap=normal_tile_overlap,
+ border_tile_overlap=border_tile_overlap,
+ height=target_height,
+ width=target_width,
+ original_size=(original_width, original_height),
+ target_size=(target_width, target_height),
+ guidance_scale=guidance_scale,
+ strength=float(denoising_strenght),
+ tile_weighting_method=tile_weighting_method,
+ max_tile_size=max_tile_size,
+ tile_gaussian_sigma=float(tile_gaussian_sigma),
+ num_inference_steps=num_inference_steps,
+ )["images"][0]
+ ```
+"""
+
+
+# This function was copied and adapted from https://huggingface.co/spaces/gokaygokay/TileUpscalerV2, licensed under Apache 2.0.
+def _adaptive_tile_size(image_size, base_tile_size=512, max_tile_size=1280):
+ """
+ Calculate the adaptive tile size based on the image dimensions, ensuring the tile
+ respects the aspect ratio and stays within the specified size limits.
+ """
+ width, height = image_size
+ aspect_ratio = width / height
+
+ if aspect_ratio > 1:
+ # Landscape orientation
+ tile_width = min(width, max_tile_size)
+ tile_height = min(int(tile_width / aspect_ratio), max_tile_size)
+ else:
+ # Portrait or square orientation
+ tile_height = min(height, max_tile_size)
+ tile_width = min(int(tile_height * aspect_ratio), max_tile_size)
+
+ # Ensure the tile size is not smaller than the base_tile_size
+ tile_width = max(tile_width, base_tile_size)
+ tile_height = max(tile_height, base_tile_size)
+
+ return tile_width, tile_height
+
+
+# Copied and adapted from https://github.com/huggingface/diffusers/blob/main/examples/community/mixture_tiling.py
+def _tile2pixel_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, image_width, image_height
+):
+ """Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image
+
+ Returns a tuple with:
+ - Starting coordinates of rows in pixel space
+ - Ending coordinates of rows in pixel space
+ - Starting coordinates of columns in pixel space
+ - Ending coordinates of columns in pixel space
+ """
+ # Calculate initial indices
+ px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap)
+ px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap)
+
+ # Calculate end indices
+ px_row_end = px_row_init + tile_height
+ px_col_end = px_col_init + tile_width
+
+ # Ensure the last tile does not exceed the image dimensions
+ px_row_end = min(px_row_end, image_height)
+ px_col_end = min(px_col_end, image_width)
+
+ return px_row_init, px_row_end, px_col_init, px_col_end
+
+
+# Copied and adapted from https://github.com/huggingface/diffusers/blob/main/examples/community/mixture_tiling.py
+def _tile2latent_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, image_width, image_height
+):
+ """Given a tile row and column numbers returns the range of latents affected by that tiles in the overall image
+
+ Returns a tuple with:
+ - Starting coordinates of rows in latent space
+ - Ending coordinates of rows in latent space
+ - Starting coordinates of columns in latent space
+ - Ending coordinates of columns in latent space
+ """
+ # Get pixel indices
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, image_width, image_height
+ )
+
+ # Convert to latent space
+ latent_row_init = px_row_init // 8
+ latent_row_end = px_row_end // 8
+ latent_col_init = px_col_init // 8
+ latent_col_end = px_col_end // 8
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+
+ # Ensure the last tile does not exceed the latent dimensions
+ latent_row_end = min(latent_row_end, latent_height)
+ latent_col_end = min(latent_col_end, latent_width)
+
+ return latent_row_init, latent_row_end, latent_col_init, latent_col_end
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class StableDiffusionXLControlNetTileSRPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ controlnet ([`ControlNetUnionModel`]):
+ Provides additional conditioning to the unet during the denoising process.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+ Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
+ config of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: ControlNetUnionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ if not isinstance(controlnet, ControlNetUnionModel):
+ raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+
+ def calculate_overlap(self, width, height, base_overlap=128):
+ """
+ Calculates dynamic overlap based on the image's aspect ratio.
+
+ Args:
+ width (int): Width of the image in pixels.
+ height (int): Height of the image in pixels.
+ base_overlap (int, optional): Base overlap value in pixels. Defaults to 128.
+
+ Returns:
+ tuple: A tuple containing:
+ - row_overlap (int): Overlap between tiles in consecutive rows.
+ - col_overlap (int): Overlap between tiles in consecutive columns.
+ """
+ ratio = height / width
+ if ratio < 1: # Image is wider than tall
+ return base_overlap // 2, base_overlap
+ else: # Image is taller than wide
+ return base_overlap, base_overlap * 2
+
+ class TileWeightingMethod(Enum):
+ """Mode in which the tile weights will be generated"""
+
+ COSINE = "Cosine"
+ GAUSSIAN = "Gaussian"
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+ dtype = text_encoders[0].dtype
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_encoder.to(dtype)
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ image,
+ strength,
+ num_inference_steps,
+ normal_tile_overlap,
+ border_tile_overlap,
+ max_tile_size,
+ tile_gaussian_sigma,
+ tile_weighting_method,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+ if num_inference_steps is None:
+ raise ValueError("`num_inference_steps` cannot be None.")
+ elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+ raise ValueError(
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+ f" {type(num_inference_steps)}."
+ )
+ if normal_tile_overlap is None:
+ raise ValueError("`normal_tile_overlap` cannot be None.")
+ elif not isinstance(normal_tile_overlap, int) or normal_tile_overlap < 64:
+ raise ValueError(
+ f"`normal_tile_overlap` has to be greater than 64 but is {normal_tile_overlap} of type"
+ f" {type(normal_tile_overlap)}."
+ )
+ if border_tile_overlap is None:
+ raise ValueError("`border_tile_overlap` cannot be None.")
+ elif not isinstance(border_tile_overlap, int) or border_tile_overlap < 128:
+ raise ValueError(
+ f"`border_tile_overlap` has to be greater than 128 but is {border_tile_overlap} of type"
+ f" {type(border_tile_overlap)}."
+ )
+ if max_tile_size is None:
+ raise ValueError("`max_tile_size` cannot be None.")
+ elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280):
+ raise ValueError(
+ f"`max_tile_size` has to be in 1024 or 1280 but is {max_tile_size} of type" f" {type(max_tile_size)}."
+ )
+ if tile_gaussian_sigma is None:
+ raise ValueError("`tile_gaussian_sigma` cannot be None.")
+ elif not isinstance(tile_gaussian_sigma, float) or tile_gaussian_sigma <= 0:
+ raise ValueError(
+ f"`tile_gaussian_sigma` has to be a positive float but is {tile_gaussian_sigma} of type"
+ f" {type(tile_gaussian_sigma)}."
+ )
+ if tile_weighting_method is None:
+ raise ValueError("`tile_weighting_method` cannot be None.")
+ elif not isinstance(tile_weighting_method, str) or tile_weighting_method not in [
+ t.value for t in self.TileWeightingMethod
+ ]:
+ raise ValueError(
+ f"`tile_weighting_method` has to be a string in ({[t.value for t in self.TileWeightingMethod]}) but is {tile_weighting_method} of type"
+ f" {type(tile_weighting_method)}."
+ )
+
+ # Check `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(image, prompt)
+ elif (
+ isinstance(self.controlnet, ControlNetUnionModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
+ ):
+ self.check_image(image, prompt)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetUnionModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
+ ) or (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
+ def check_image(self, image, prompt):
+ image_is_pil = isinstance(image, Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
+ def prepare_latents(
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
+ ):
+ if not isinstance(image, (torch.Tensor, Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ latents_mean = latents_std = None
+ if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
+
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.text_encoder_2.to("cpu")
+ torch.cuda.empty_cache()
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+ image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+ elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+ )
+
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=device, dtype=dtype)
+ latents_std = latents_std.to(device=device, dtype=dtype)
+ init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
+ else:
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ if add_noise:
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+
+ latents = init_latents
+
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+ )
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ def _generate_cosine_weights(self, tile_width, tile_height, nbatches, device, dtype):
+ """
+ Generates cosine weights as a PyTorch tensor for blending tiles.
+
+ Args:
+ tile_width (int): Width of the tile in pixels.
+ tile_height (int): Height of the tile in pixels.
+ nbatches (int): Number of batches.
+ device (torch.device): Device where the tensor will be allocated (e.g., 'cuda' or 'cpu').
+ dtype (torch.dtype): Data type of the tensor (e.g., torch.float32).
+
+ Returns:
+ torch.Tensor: A tensor containing cosine weights for blending tiles, expanded to match batch and channel dimensions.
+ """
+ # Convert tile dimensions to latent space
+ latent_width = tile_width // 8
+ latent_height = tile_height // 8
+
+ # Generate x and y coordinates in latent space
+ x = np.arange(0, latent_width)
+ y = np.arange(0, latent_height)
+
+ # Calculate midpoints
+ midpoint_x = (latent_width - 1) / 2
+ midpoint_y = (latent_height - 1) / 2
+
+ # Compute cosine probabilities for x and y
+ x_probs = np.cos(np.pi * (x - midpoint_x) / latent_width)
+ y_probs = np.cos(np.pi * (y - midpoint_y) / latent_height)
+
+ # Create a 2D weight matrix using the outer product
+ weights_np = np.outer(y_probs, x_probs)
+
+ # Convert to a PyTorch tensor with the correct device and dtype
+ weights_torch = torch.tensor(weights_np, device=device, dtype=dtype)
+
+ # Expand for batch and channel dimensions
+ tile_weights_expanded = torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))
+
+ return tile_weights_expanded
+
+ def _generate_gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype, sigma=0.05):
+ """
+ Generates Gaussian weights as a PyTorch tensor for blending tiles in latent space.
+
+ Args:
+ tile_width (int): Width of the tile in pixels.
+ tile_height (int): Height of the tile in pixels.
+ nbatches (int): Number of batches.
+ device (torch.device): Device where the tensor will be allocated (e.g., 'cuda' or 'cpu').
+ dtype (torch.dtype): Data type of the tensor (e.g., torch.float32).
+ sigma (float, optional): Standard deviation of the Gaussian distribution. Controls the smoothness of the weights. Defaults to 0.05.
+
+ Returns:
+ torch.Tensor: A tensor containing Gaussian weights for blending tiles, expanded to match batch and channel dimensions.
+ """
+ # Convert tile dimensions to latent space
+ latent_width = tile_width // 8
+ latent_height = tile_height // 8
+
+ # Generate Gaussian weights in latent space
+ x = np.linspace(-1, 1, latent_width)
+ y = np.linspace(-1, 1, latent_height)
+ xx, yy = np.meshgrid(x, y)
+ gaussian_weight = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
+
+ # Convert to a PyTorch tensor with the correct device and dtype
+ weights_torch = torch.tensor(gaussian_weight, device=device, dtype=dtype)
+
+ # Expand for batch and channel dimensions
+ weights_expanded = weights_torch.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions
+ weights_expanded = weights_expanded.expand(nbatches, -1, -1, -1) # Expand to the number of batches
+
+ return weights_expanded
+
+ def _get_num_tiles(self, height, width, tile_height, tile_width, normal_tile_overlap, border_tile_overlap):
+ """
+ Calculates the number of tiles needed to cover an image, choosing the appropriate formula based on the
+ ratio between the image size and the tile size.
+
+ This function automatically selects between two formulas:
+ 1. A universal formula for typical cases (image-to-tile ratio <= 6:1).
+ 2. A specialized formula with border tile overlap for larger or atypical cases (image-to-tile ratio > 6:1).
+
+ Args:
+ height (int): Height of the image in pixels.
+ width (int): Width of the image in pixels.
+ tile_height (int): Height of each tile in pixels.
+ tile_width (int): Width of each tile in pixels.
+ normal_tile_overlap (int): Overlap between tiles in pixels for normal (non-border) tiles.
+ border_tile_overlap (int): Overlap between tiles in pixels for border tiles.
+
+ Returns:
+ tuple: A tuple containing:
+ - grid_rows (int): Number of rows in the tile grid.
+ - grid_cols (int): Number of columns in the tile grid.
+
+ Notes:
+ - The function uses the universal formula (without border_tile_overlap) for typical cases where the
+ image-to-tile ratio is 6:1 or smaller.
+ - For larger or atypical cases (image-to-tile ratio > 6:1), it uses a specialized formula that includes
+ border_tile_overlap to ensure complete coverage of the image, especially at the edges.
+ """
+ # Calculate the ratio between the image size and the tile size
+ height_ratio = height / tile_height
+ width_ratio = width / tile_width
+
+ # If the ratio is greater than 6:1, use the formula with border_tile_overlap
+ if height_ratio > 6 or width_ratio > 6:
+ grid_rows = int(np.ceil((height - border_tile_overlap) / (tile_height - normal_tile_overlap))) + 1
+ grid_cols = int(np.ceil((width - border_tile_overlap) / (tile_width - normal_tile_overlap))) + 1
+ else:
+ # Otherwise, use the universal formula
+ grid_rows = int(np.ceil((height - normal_tile_overlap) / (tile_height - normal_tile_overlap)))
+ grid_cols = int(np.ceil((width - normal_tile_overlap) / (tile_width - normal_tile_overlap)))
+
+ return grid_rows, grid_cols
+
+ def prepare_tiles(
+ self,
+ grid_rows,
+ grid_cols,
+ tile_weighting_method,
+ tile_width,
+ tile_height,
+ normal_tile_overlap,
+ border_tile_overlap,
+ width,
+ height,
+ tile_sigma,
+ batch_size,
+ device,
+ dtype,
+ ):
+ """
+ Processes image tiles by dynamically adjusting overlap and calculating Gaussian or cosine weights.
+
+ Args:
+ grid_rows (int): Number of rows in the tile grid.
+ grid_cols (int): Number of columns in the tile grid.
+ tile_weighting_method (str): Method for weighting tiles. Options: "Gaussian" or "Cosine".
+ tile_width (int): Width of each tile in pixels.
+ tile_height (int): Height of each tile in pixels.
+ normal_tile_overlap (int): Overlap between tiles in pixels for normal tiles.
+ border_tile_overlap (int): Overlap between tiles in pixels for border tiles.
+ width (int): Width of the image in pixels.
+ height (int): Height of the image in pixels.
+ tile_sigma (float): Sigma parameter for Gaussian weighting.
+ batch_size (int): Batch size for weight tiles.
+ device (torch.device): Device where tensors will be allocated (e.g., 'cuda' or 'cpu').
+ dtype (torch.dtype): Data type of the tensors (e.g., torch.float32).
+
+ Returns:
+ tuple: A tuple containing:
+ - tile_weights (np.ndarray): Array of weights for each tile.
+ - tile_row_overlaps (np.ndarray): Array of row overlaps for each tile.
+ - tile_col_overlaps (np.ndarray): Array of column overlaps for each tile.
+ """
+
+ # Create arrays to store dynamic overlaps and weights
+ tile_row_overlaps = np.full((grid_rows, grid_cols), normal_tile_overlap)
+ tile_col_overlaps = np.full((grid_rows, grid_cols), normal_tile_overlap)
+ tile_weights = np.empty((grid_rows, grid_cols), dtype=object) # Stores Gaussian or cosine weights
+
+ # Iterate over tiles to adjust overlap and calculate weights
+ for row in range(grid_rows):
+ for col in range(grid_cols):
+ # Calculate the size of the current tile
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
+ row, col, tile_width, tile_height, normal_tile_overlap, normal_tile_overlap, width, height
+ )
+ current_tile_width = px_col_end - px_col_init
+ current_tile_height = px_row_end - px_row_init
+ sigma = tile_sigma
+
+ # Adjust overlap for smaller tiles
+ if current_tile_width < tile_width:
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
+ row, col, tile_width, tile_height, border_tile_overlap, border_tile_overlap, width, height
+ )
+ current_tile_width = px_col_end - px_col_init
+ tile_col_overlaps[row, col] = border_tile_overlap
+ sigma = tile_sigma * 1.2
+ if current_tile_height < tile_height:
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
+ row, col, tile_width, tile_height, border_tile_overlap, border_tile_overlap, width, height
+ )
+ current_tile_height = px_row_end - px_row_init
+ tile_row_overlaps[row, col] = border_tile_overlap
+ sigma = tile_sigma * 1.2
+
+ # Calculate weights for the current tile
+ if tile_weighting_method == self.TileWeightingMethod.COSINE.value:
+ tile_weights[row, col] = self._generate_cosine_weights(
+ tile_width=current_tile_width,
+ tile_height=current_tile_height,
+ nbatches=batch_size,
+ device=device,
+ dtype=torch.float32,
+ )
+ else:
+ tile_weights[row, col] = self._generate_gaussian_weights(
+ tile_width=current_tile_width,
+ tile_height=current_tile_height,
+ nbatches=batch_size,
+ device=device,
+ dtype=dtype,
+ sigma=sigma,
+ )
+
+ return tile_weights, tile_row_overlaps, tile_col_overlaps
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.9999,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_mode: Optional[Union[int, List[int]]] = None,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ normal_tile_overlap: int = 64,
+ border_tile_overlap: int = 128,
+ max_tile_size: int = 1024,
+ tile_gaussian_sigma: float = 0.05,
+ tile_weighting_method: str = "Cosine",
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, *optional*):
+ The initial image to be used as the starting point for the image generation process. Can also accept
+ image latents as `image`, if passing latents directly, they will not be encoded again.
+ control_image (`PipelineImageInput`, *optional*):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance for Unet.
+ If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
+ be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
+ init, images must be passed as a list such that each element of the list can be correctly batched for
+ input to a single ControlNet.
+ height (`int`, *optional*):
+ The height in pixels of the generated image. If not provided, defaults to the height of `control_image`.
+ width (`int`, *optional*):
+ The width in pixels of the generated image. If not provided, defaults to the width of `control_image`.
+ strength (`float`, *optional*, defaults to 0.9999):
+ Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point, and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum, and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
+ Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages generating
+ images closely linked to the text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/):
+ `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original UNet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder will try to recognize the content of the input image even if
+ you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ control_mode (`int` or `List[int]`, *optional*):
+ The mode of ControlNet guidance. Can be used to specify different behaviors for multiple ControlNets.
+ original_size (`Tuple[int, int]`, *optional*):
+ If `original_size` is not the same as `target_size`, the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning.
+ crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning.
+ target_size (`Tuple[int, int]`, *optional*):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified, it will default to `(height, width)`. Part of SDXL's micro-conditioning.
+ negative_original_size (`Tuple[int, int]`, *optional*):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning.
+ negative_crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning.
+ negative_target_size (`Tuple[int, int]`, *optional*):
+ To negatively condition the generation process based on a target image resolution. It should be the same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning.
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Used to simulate an aesthetic score of the generated image by influencing the negative text condition.
+ Part of SDXL's micro-conditioning.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ normal_tile_overlap (`int`, *optional*, defaults to 64):
+ Number of overlapping pixels between tiles in consecutive rows.
+ border_tile_overlap (`int`, *optional*, defaults to 128):
+ Number of overlapping pixels between tiles at the borders.
+ max_tile_size (`int`, *optional*, defaults to 1024):
+ Maximum size of a tile in pixels.
+ tile_gaussian_sigma (`float`, *optional*, defaults to 0.3):
+ Sigma parameter for Gaussian weighting of tiles.
+ tile_weighting_method (`str`, *optional*, defaults to "Cosine"):
+ Method for weighting tiles. Options: "Cosine" or "Gaussian".
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
+ containing the output images.
+ """
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+
+ if not isinstance(control_image, list):
+ control_image = [control_image]
+ else:
+ control_image = control_image.copy()
+
+ if control_mode is None or isinstance(control_mode, list) and len(control_mode) == 0:
+ raise ValueError("The value for `control_mode` is expected!")
+
+ if not isinstance(control_mode, list):
+ control_mode = [control_mode]
+
+ if len(control_image) != len(control_mode):
+ raise ValueError("Expected len(control_image) == len(control_mode)")
+
+ num_control_type = controlnet.config.num_control_type
+
+ # 0. Set internal use parameters
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+ negative_original_size = negative_original_size or original_size
+ negative_target_size = negative_target_size or target_size
+ control_type = [0 for _ in range(num_control_type)]
+ control_type = torch.Tensor(control_type)
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+ batch_size = 1
+ device = self._execution_device
+ global_pool_conditions = controlnet.config.global_pool_conditions
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 1. Check inputs
+ for _image, control_idx in zip(control_image, control_mode):
+ control_type[control_idx] = 1
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ _image,
+ strength,
+ num_inference_steps,
+ normal_tile_overlap,
+ border_tile_overlap,
+ max_tile_size,
+ tile_gaussian_sigma,
+ tile_weighting_method,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2 Get tile width and tile height size
+ tile_width, tile_height = _adaptive_tile_size((width, height), max_tile_size=max_tile_size)
+
+ # 2.1 Calculate the number of tiles needed
+ grid_rows, grid_cols = self._get_num_tiles(
+ height, width, tile_height, tile_width, normal_tile_overlap, border_tile_overlap
+ )
+
+ # 2.2 Expand prompt to number of tiles
+ if not isinstance(prompt, list):
+ prompt = [[prompt] * grid_cols] * grid_rows
+
+ # 2.3 Update height and width tile size by tile size and tile overlap size
+ width = (grid_cols - 1) * (tile_width - normal_tile_overlap) + min(
+ tile_width, width - (grid_cols - 1) * (tile_width - normal_tile_overlap)
+ )
+ height = (grid_rows - 1) * (tile_height - normal_tile_overlap) + min(
+ tile_height, height - (grid_rows - 1) * (tile_height - normal_tile_overlap)
+ )
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ text_embeddings = [
+ [
+ self.encode_prompt(
+ prompt=col,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ for col in row
+ ]
+ for row in prompt
+ ]
+
+ # 4. Prepare latent image
+ image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+
+ # 4.1 Prepare controlnet_conditioning_image
+ control_image = self.prepare_control_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ control_type = (
+ control_type.reshape(1, -1)
+ .to(device, dtype=controlnet.dtype)
+ .repeat(batch_size * num_images_per_prompt * 2, 1)
+ )
+
+ # 5. Prepare timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+ self.scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare latent variables
+ dtype = text_embeddings[0][0][0].dtype
+ if latents is None:
+ latents = self.prepare_latents(
+ image_tensor,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ dtype,
+ device,
+ generator,
+ True,
+ )
+
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ controlnet_keep.append(
+ 1.0
+ - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
+ )
+
+ # 8.1 Prepare added time ids & embeddings
+ # text_embeddings order: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+ embeddings_and_added_time = []
+ crops_coords_top_left = negative_crops_coords_top_left = (tile_width, tile_height)
+ for row in range(grid_rows):
+ addition_embed_type_row = []
+ for col in range(grid_cols):
+ # extract generated values
+ prompt_embeds = text_embeddings[row][col][0]
+ negative_prompt_embeds = text_embeddings[row][col][1]
+ pooled_prompt_embeds = text_embeddings[row][col][2]
+ negative_pooled_prompt_embeds = text_embeddings[row][col][3]
+
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+ add_text_embeds = pooled_prompt_embeds
+
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+ addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))
+
+ embeddings_and_added_time.append(addition_embed_type_row)
+
+ # 9. Prepare tiles weights and latent overlaps size to denoising process
+ tile_weights, tile_row_overlaps, tile_col_overlaps = self.prepare_tiles(
+ grid_rows,
+ grid_cols,
+ tile_weighting_method,
+ tile_width,
+ tile_height,
+ normal_tile_overlap,
+ border_tile_overlap,
+ width,
+ height,
+ tile_gaussian_sigma,
+ batch_size,
+ device,
+ dtype,
+ )
+
+ # 10. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Diffuse each tile
+ noise_preds = []
+ for row in range(grid_rows):
+ noise_preds_row = []
+ for col in range(grid_cols):
+ if self.interrupt:
+ continue
+ tile_row_overlap = tile_row_overlaps[row, col]
+ tile_col_overlap = tile_col_overlaps[row, col]
+
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, width, height
+ )
+
+ tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([tile_latents] * 2)
+ if self.do_classifier_free_guidance
+ else tile_latents # 1, 4, ...
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {
+ "text_embeds": embeddings_and_added_time[row][col][1],
+ "time_ids": embeddings_and_added_time[row][col][2],
+ }
+
+ # controlnet(s) inference
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = tile_latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = embeddings_and_added_time[row][col][0].chunk(2)[1]
+ controlnet_added_cond_kwargs = {
+ "text_embeds": embeddings_and_added_time[row][col][1].chunk(2)[1],
+ "time_ids": embeddings_and_added_time[row][col][2].chunk(2)[1],
+ }
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = embeddings_and_added_time[row][col][0]
+ controlnet_added_cond_kwargs = added_cond_kwargs
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ px_row_init_pixel, px_row_end_pixel, px_col_init_pixel, px_col_end_pixel = _tile2pixel_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, width, height
+ )
+
+ tile_control_image = control_image[
+ :, :, px_row_init_pixel:px_row_end_pixel, px_col_init_pixel:px_col_end_pixel
+ ]
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=[tile_control_image],
+ control_type=control_type,
+ control_type_idx=control_mode,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [
+ torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.cat(
+ [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
+ )
+
+ # predict the noise residual
+ with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=embeddings_and_added_time[row][col][0],
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_tile = noise_pred_uncond + guidance_scale * (
+ noise_pred_text - noise_pred_uncond
+ )
+ noise_preds_row.append(noise_pred_tile)
+ noise_preds.append(noise_preds_row)
+
+ # Stitch noise predictions for all tiles
+ noise_pred = torch.zeros(latents.shape, device=device)
+ contributors = torch.zeros(latents.shape, device=device)
+
+ # Add each tile contribution to overall latents
+ for row in range(grid_rows):
+ for col in range(grid_cols):
+ tile_row_overlap = tile_row_overlaps[row, col]
+ tile_col_overlap = tile_col_overlaps[row, col]
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, width, height
+ )
+ tile_weights_resized = tile_weights[row, col]
+
+ noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += (
+ noise_preds[row][col] * tile_weights_resized
+ )
+ contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights_resized
+
+ # Average overlapping areas with more than 1 contributor
+ noise_pred /= contributors
+ noise_pred = noise_pred.to(dtype)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ # update progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ else:
+ image = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ result = StableDiffusionXLPipelineOutput(images=image)
+ if not return_dict:
+ return (image,)
+
+ return result
diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py
index dc335e0b585e..5dcc75c9e20b 100644
--- a/examples/community/multilingual_stable_diffusion.py
+++ b/examples/community/multilingual_stable_diffusion.py
@@ -98,7 +98,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
diff --git a/examples/community/pipeline_animatediff_controlnet.py b/examples/community/pipeline_animatediff_controlnet.py
index bedf002d024c..9f99ad248be2 100644
--- a/examples/community/pipeline_animatediff_controlnet.py
+++ b/examples/community/pipeline_animatediff_controlnet.py
@@ -188,7 +188,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
diff --git a/examples/community/pipeline_animatediff_img2video.py b/examples/community/pipeline_animatediff_img2video.py
index 0a578d4b8ef6..f7f0cf31c5dd 100644
--- a/examples/community/pipeline_animatediff_img2video.py
+++ b/examples/community/pipeline_animatediff_img2video.py
@@ -308,7 +308,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
diff --git a/examples/community/pipeline_animatediff_ipex.py b/examples/community/pipeline_animatediff_ipex.py
index dc65e76bc43b..06508f217c4c 100644
--- a/examples/community/pipeline_animatediff_ipex.py
+++ b/examples/community/pipeline_animatediff_ipex.py
@@ -162,7 +162,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
diff --git a/examples/community/pipeline_demofusion_sdxl.py b/examples/community/pipeline_demofusion_sdxl.py
index f83d1b401420..624b2bd1ed81 100644
--- a/examples/community/pipeline_demofusion_sdxl.py
+++ b/examples/community/pipeline_demofusion_sdxl.py
@@ -166,9 +166,13 @@ def __init__(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -290,7 +294,9 @@ def encode_prompt(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
@@ -342,7 +348,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_fabric.py b/examples/community/pipeline_fabric.py
index 02fdcd04c103..30847f875bda 100644
--- a/examples/community/pipeline_fabric.py
+++ b/examples/community/pipeline_fabric.py
@@ -150,10 +150,14 @@ def __init__(
):
super().__init__()
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -179,7 +183,7 @@ def __init__(
tokenizer=tokenizer,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
diff --git a/examples/community/pipeline_faithdiff_stable_diffusion_xl.py b/examples/community/pipeline_faithdiff_stable_diffusion_xl.py
new file mode 100644
index 000000000000..d1d3d80b4a60
--- /dev/null
+++ b/examples/community/pipeline_faithdiff_stable_diffusion_xl.py
@@ -0,0 +1,2269 @@
+# Copyright 2025 Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab Team
+# and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import inspect
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ PeftAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+ UNet2DConditionLoadersMixin,
+)
+from diffusers.models import AutoencoderKL
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import DDPMScheduler, KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ is_torch_version,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.outputs import BaseOutput
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import random
+ >>> import numpy as np
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline, AutoencoderKL, UniPCMultistepScheduler
+ >>> from huggingface_hub import hf_hub_download
+ >>> from diffusers.utils import load_image
+ >>> from PIL import Image
+ >>>
+ >>> device = "cuda"
+ >>> dtype = torch.float16
+ >>> MAX_SEED = np.iinfo(np.int32).max
+ >>>
+ >>> # Download weights for additional unet layers
+ >>> model_file = hf_hub_download(
+ ... "jychen9811/FaithDiff",
+ ... filename="FaithDiff.bin", local_dir="./proc_data/faithdiff", local_dir_use_symlinks=False
+ ... )
+ >>>
+ >>> # Initialize the models and pipeline
+ >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
+ >>>
+ >>> model_id = "SG161222/RealVisXL_V4.0"
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... model_id,
+ ... torch_dtype=dtype,
+ ... vae=vae,
+ ... unet=None, #<- Do not load with original model.
+ ... custom_pipeline="mixture_tiling_sdxl",
+ ... use_safetensors=True,
+ ... variant="fp16",
+ ... ).to(device)
+ >>>
+ >>> # Here we need use pipeline internal unet model
+ >>> pipe.unet = pipe.unet_model.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
+ >>>
+ >>> # Load aditional layers to the model
+ >>> pipe.unet.load_additional_layers(weight_path="proc_data/faithdiff/FaithDiff.bin", dtype=dtype)
+ >>>
+ >>> # Enable vae tiling
+ >>> pipe.set_encoder_tile_settings()
+ >>> pipe.enable_vae_tiling()
+ >>>
+ >>> # Optimization
+ >>> pipe.enable_model_cpu_offload()
+ >>>
+ >>> # Set selected scheduler
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>>
+ >>> #input params
+ >>> prompt = "The image features a woman in her 55s with blonde hair and a white shirt, smiling at the camera. She appears to be in a good mood and is wearing a white scarf around her neck. "
+ >>> upscale = 2 # scale here
+ >>> start_point = "lr" # or "noise"
+ >>> latent_tiled_overlap = 0.5
+ >>> latent_tiled_size = 1024
+ >>>
+ >>> # Load image
+ >>> lq_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/woman.png")
+ >>> original_height = lq_image.height
+ >>> original_width = lq_image.width
+ >>> print(f"Current resolution: H:{original_height} x W:{original_width}")
+ >>>
+ >>> width = original_width * int(upscale)
+ >>> height = original_height * int(upscale)
+ >>> print(f"Final resolution: H:{height} x W:{width}")
+ >>>
+ >>> # Restoration
+ >>> image = lq_image.resize((width, height), Image.LANCZOS)
+ >>> input_image, width_init, height_init, width_now, height_now = pipe.check_image_size(image)
+ >>>
+ >>> generator = torch.Generator(device=device).manual_seed(random.randint(0, MAX_SEED))
+ >>> gen_image = pipe(lr_img=input_image,
+ ... prompt = prompt,
+ ... num_inference_steps=20,
+ ... guidance_scale=5,
+ ... generator=generator,
+ ... start_point=start_point,
+ ... height = height_now,
+ ... width=width_now,
+ ... overlap=latent_tiled_overlap,
+ ... target_size=(latent_tiled_size, latent_tiled_size)
+ ... ).images[0]
+ >>>
+ >>> cropped_image = gen_image.crop((0, 0, width_init, height_init))
+ >>> cropped_image.save("data/result.png")
+ ```
+"""
+
+
+def zero_module(module):
+ """Zero out the parameters of a module and return it."""
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
+
+
+class Encoder(nn.Module):
+ """Encoder layer of a variational autoencoder that encodes input into a latent representation."""
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 4,
+ down_block_types: Tuple[str, ...] = (
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ),
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ double_z: bool = True,
+ mid_block_add_attention: bool = True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+ self.use_rgb = False
+ self.down_block_type = down_block_types
+ self.block_out_channels = block_out_channels
+
+ self.tile_sample_min_size = 1024
+ self.tile_latent_min_size = int(self.tile_sample_min_size / 8)
+ self.tile_overlap_factor = 0.25
+ self.use_tiling = False
+
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=self.layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=output_channel,
+ temb_channels=None,
+ )
+ self.down_blocks.append(down_block)
+
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=None,
+ add_attention=mid_block_add_attention,
+ )
+
+ self.gradient_checkpointing = False
+
+ def to_rgb_init(self):
+ """Initialize layers to convert features to RGB."""
+ self.to_rgbs = nn.ModuleList([])
+ self.use_rgb = True
+ for i, down_block_type in enumerate(self.down_block_type):
+ output_channel = self.block_out_channels[i]
+ self.to_rgbs.append(nn.Conv2d(output_channel, 3, kernel_size=3, padding=1))
+
+ def enable_tiling(self):
+ """Enable tiling for large inputs."""
+ self.use_tiling = True
+
+ def encode(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """Encode the input tensor into a latent representation."""
+ sample = self.conv_in(sample)
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ for down_block in self.down_blocks:
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(down_block), sample, use_reentrant=False
+ )
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block), sample, use_reentrant=False
+ )
+ else:
+ for down_block in self.down_blocks:
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+ return sample
+ else:
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+ sample = self.mid_block(sample)
+ return sample
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ """Blend two tensors vertically with a smooth transition."""
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ """Blend two tensors horizontally with a smooth transition."""
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+ return b
+
+ def tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ """Encode the input tensor using tiling for large inputs."""
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ rows = []
+ for i in range(0, x.shape[2], overlap_size):
+ row = []
+ for j in range(0, x.shape[3], overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encode(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ moments = torch.cat(result_rows, dim=2)
+ return moments
+
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """Forward pass of the encoder, using tiling if enabled for large inputs."""
+ if self.use_tiling and (
+ sample.shape[-1] > self.tile_latent_min_size or sample.shape[-2] > self.tile_latent_min_size
+ ):
+ return self.tiled_encode(sample)
+ return self.encode(sample)
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+ """A small network to preprocess conditioning inputs, inspired by ControlNet."""
+
+ def __init__(self, conditioning_embedding_channels: int, conditioning_channels: int = 4):
+ super().__init__()
+ self.conv_in = nn.Conv2d(conditioning_channels, conditioning_channels, kernel_size=3, padding=1)
+ self.norm_in = nn.GroupNorm(num_channels=conditioning_channels, num_groups=32, eps=1e-6)
+ self.conv_out = zero_module(
+ nn.Conv2d(conditioning_channels, conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+ """Process the conditioning input through the network."""
+ conditioning = self.norm_in(conditioning)
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+ embedding = self.conv_out(embedding)
+ return embedding
+
+
+class QuickGELU(nn.Module):
+ """A fast approximation of the GELU activation function."""
+
+ def forward(self, x: torch.Tensor):
+ """Apply the QuickGELU activation to the input tensor."""
+ return x * torch.sigmoid(1.702 * x)
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ """Apply LayerNorm and preserve the input dtype."""
+ orig_type = x.dtype
+ ret = super().forward(x)
+ return ret.type(orig_type)
+
+
+class ResidualAttentionBlock(nn.Module):
+ """A transformer-style block with self-attention and an MLP."""
+
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(
+ OrderedDict(
+ [
+ ("c_fc", nn.Linear(d_model, d_model * 2)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 2, d_model)),
+ ]
+ )
+ )
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ """Apply self-attention to the input tensor."""
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ """Forward pass through the residual attention block."""
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """The output of UnifiedUNet2DConditionModel."""
+
+ sample: torch.FloatTensor = None
+
+
+class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+ """A unified 2D UNet model extending OriginalUNet2DConditionModel with custom functionality."""
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ """Initialize the UnifiedUNet2DConditionModel."""
+ super().__init__(
+ sample_size=sample_size,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ center_input_sample=center_input_sample,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ down_block_types=down_block_types,
+ mid_block_type=mid_block_type,
+ up_block_types=up_block_types,
+ only_cross_attention=only_cross_attention,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ downsample_padding=downsample_padding,
+ mid_block_scale_factor=mid_block_scale_factor,
+ dropout=dropout,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ cross_attention_dim=cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
+ encoder_hid_dim=encoder_hid_dim,
+ encoder_hid_dim_type=encoder_hid_dim_type,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ class_embed_type=class_embed_type,
+ addition_embed_type=addition_embed_type,
+ addition_time_embed_dim=addition_time_embed_dim,
+ num_class_embeds=num_class_embeds,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ time_embedding_type=time_embedding_type,
+ time_embedding_dim=time_embedding_dim,
+ time_embedding_act_fn=time_embedding_act_fn,
+ timestep_post_act=timestep_post_act,
+ time_cond_proj_dim=time_cond_proj_dim,
+ conv_in_kernel=conv_in_kernel,
+ conv_out_kernel=conv_out_kernel,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ attention_type=attention_type,
+ class_embeddings_concat=class_embeddings_concat,
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
+ )
+
+ # Additional attributes
+ self.denoise_encoder = None
+ self.information_transformer_layes = None
+ self.condition_embedding = None
+ self.agg_net = None
+ self.spatial_ch_projs = None
+
+ def init_vae_encoder(self, dtype):
+ self.denoise_encoder = Encoder()
+ if dtype is not None:
+ self.denoise_encoder.dtype = dtype
+
+ def init_information_transformer_layes(self):
+ num_trans_channel = 640
+ num_trans_head = 8
+ num_trans_layer = 2
+ num_proj_channel = 320
+ self.information_transformer_layes = nn.Sequential(
+ *[ResidualAttentionBlock(num_trans_channel, num_trans_head) for _ in range(num_trans_layer)]
+ )
+ self.spatial_ch_projs = zero_module(nn.Linear(num_trans_channel, num_proj_channel))
+
+ def init_ControlNetConditioningEmbedding(self, channel=512):
+ self.condition_embedding = ControlNetConditioningEmbedding(320, channel)
+
+ def init_extra_weights(self):
+ self.agg_net = nn.ModuleList()
+
+ def load_additional_layers(
+ self, dtype: Optional[torch.dtype] = torch.float16, channel: int = 512, weight_path: Optional[str] = None
+ ):
+ """Load additional layers and weights from a file.
+
+ Args:
+ weight_path (str): Path to the weight file.
+ dtype (torch.dtype, optional): Data type for the loaded weights. Defaults to torch.float16.
+ channel (int): Conditioning embedding channel out size. Defaults 512.
+ """
+ if self.denoise_encoder is None:
+ self.init_vae_encoder(dtype)
+
+ if self.information_transformer_layes is None:
+ self.init_information_transformer_layes()
+
+ if self.condition_embedding is None:
+ self.init_ControlNetConditioningEmbedding(channel)
+
+ if self.agg_net is None:
+ self.init_extra_weights()
+
+ # Load weights if provided
+ if weight_path is not None:
+ state_dict = torch.load(weight_path, weights_only=False)
+ self.load_state_dict(state_dict, strict=True)
+
+ # Move all modules to the same device and dtype as the model
+ device = next(self.parameters()).device
+ if dtype is not None or device is not None:
+ self.to(device=device, dtype=dtype or next(self.parameters()).dtype)
+
+ def to(self, *args, **kwargs):
+ """Override to() to move all additional modules to the same device and dtype."""
+ super().to(*args, **kwargs)
+ for module in [
+ self.denoise_encoder,
+ self.information_transformer_layes,
+ self.condition_embedding,
+ self.agg_net,
+ self.spatial_ch_projs,
+ ]:
+ if module is not None:
+ module.to(*args, **kwargs)
+ return self
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Load state dictionary into the model.
+
+ Args:
+ state_dict (dict): State dictionary to load.
+ strict (bool, optional): Whether to strictly enforce that all keys match. Defaults to True.
+ """
+ core_dict = {}
+ additional_dicts = {
+ "denoise_encoder": {},
+ "information_transformer_layes": {},
+ "condition_embedding": {},
+ "agg_net": {},
+ "spatial_ch_projs": {},
+ }
+
+ for key, value in state_dict.items():
+ if key.startswith("denoise_encoder."):
+ additional_dicts["denoise_encoder"][key[len("denoise_encoder.") :]] = value
+ elif key.startswith("information_transformer_layes."):
+ additional_dicts["information_transformer_layes"][key[len("information_transformer_layes.") :]] = value
+ elif key.startswith("condition_embedding."):
+ additional_dicts["condition_embedding"][key[len("condition_embedding.") :]] = value
+ elif key.startswith("agg_net."):
+ additional_dicts["agg_net"][key[len("agg_net.") :]] = value
+ elif key.startswith("spatial_ch_projs."):
+ additional_dicts["spatial_ch_projs"][key[len("spatial_ch_projs.") :]] = value
+ else:
+ core_dict[key] = value
+
+ super().load_state_dict(core_dict, strict=False)
+ for module_name, module_dict in additional_dicts.items():
+ module = getattr(self, module_name, None)
+ if module is not None and module_dict:
+ module.load_state_dict(module_dict, strict=strict)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ input_embedding: Optional[torch.Tensor] = None,
+ add_sample: bool = True,
+ return_dict: bool = True,
+ use_condition_embedding: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ """Forward pass prioritizing the original modified implementation.
+
+ Args:
+ sample (torch.FloatTensor): The noisy input tensor with shape `(batch, channel, height, width)`.
+ timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+ encoder_hidden_states (torch.Tensor): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (torch.Tensor, optional): Optional class labels for conditioning.
+ timestep_cond (torch.Tensor, optional): Conditional embeddings for timestep.
+ attention_mask (torch.Tensor, optional): An attention mask of shape `(batch, key_tokens)`.
+ cross_attention_kwargs (Dict[str, Any], optional): A kwargs dictionary for the AttentionProcessor.
+ added_cond_kwargs (Dict[str, torch.Tensor], optional): Additional embeddings to add to the UNet blocks.
+ down_block_additional_residuals (Tuple[torch.Tensor], optional): Residuals for down UNet blocks.
+ mid_block_additional_residual (torch.Tensor, optional): Residual for the middle UNet block.
+ down_intrablock_additional_residuals (Tuple[torch.Tensor], optional): Additional residuals within down blocks.
+ encoder_attention_mask (torch.Tensor, optional): A cross-attention mask of shape `(batch, sequence_length)`.
+ input_embedding (torch.Tensor, optional): Additional input embedding for preprocessing.
+ add_sample (bool): Whether to add the sample to the processed embedding. Defaults to True.
+ return_dict (bool): Whether to return a UNet2DConditionOutput. Defaults to True.
+ use_condition_embedding (bool): Whether to use the condition embedding. Defaults to True.
+
+ Returns:
+ Union[UNet2DConditionOutput, Tuple]: The processed sample tensor, either as a UNet2DConditionOutput or tuple.
+ """
+ default_overall_up_factor = 2**self.num_upsamplers
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ forward_upsample_size = True
+ break
+
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
+ if class_emb is not None:
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ aug_emb = self.get_aug_embed(
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+ if self.config.addition_embed_type == "image_hint":
+ aug_emb, hint = aug_emb
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ encoder_hidden_states = self.process_encoder_hidden_states(
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+
+ # 2. pre-process (following the original modified logic)
+ sample = self.conv_in(sample) # [B, 4, H, W] -> [B, 320, H, W]
+ if (
+ input_embedding is not None
+ and self.condition_embedding is not None
+ and self.information_transformer_layes is not None
+ ):
+ if use_condition_embedding:
+ input_embedding = self.condition_embedding(input_embedding) # [B, 320, H, W]
+ batch_size, channel, height, width = input_embedding.shape
+ concat_feat = (
+ torch.cat([sample, input_embedding], dim=1)
+ .view(batch_size, 2 * channel, height * width)
+ .transpose(1, 2)
+ )
+ concat_feat = self.information_transformer_layes(concat_feat)
+ feat_alpha = self.spatial_ch_projs(concat_feat).transpose(1, 2).view(batch_size, channel, height, width)
+ sample = sample + feat_alpha if add_sample else feat_alpha # Update sample as in the original version
+
+ # 2.5 GLIGEN position net (kept from the original version)
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # 3. down (continues the standard flow)
+ if cross_attention_kwargs is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ scale_lora_layers(self, lora_scale)
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ is_adapter = down_intrablock_additional_residuals is not None
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if USE_PEFT_BACKEND:
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (sample,)
+ return UNet2DConditionOutput(sample=sample)
+
+
+class LocalAttention:
+ """A class to handle local attention by splitting tensors into overlapping grids for processing."""
+
+ def __init__(self, kernel_size=None, overlap=0.5):
+ """Initialize the LocalAttention module.
+
+ Args:
+ kernel_size (tuple[int, int], optional): Size of the grid (height, width). Defaults to None.
+ overlap (float): Overlap factor between adjacent grids (0.0 to 1.0). Defaults to 0.5.
+ """
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.overlap = overlap
+
+ def grids_list(self, x):
+ """Split the input tensor into a list of non-overlapping grid patches.
+
+ Args:
+ x (torch.Tensor): Input tensor of shape (batch, channels, height, width).
+
+ Returns:
+ list[torch.Tensor]: List of tensor patches.
+ """
+ b, c, h, w = x.shape
+ self.original_size = (b, c, h, w)
+ assert b == 1
+ k1, k2 = self.kernel_size
+ if h < k1:
+ k1 = h
+ if w < k2:
+ k2 = w
+ num_row = (h - 1) // k1 + 1
+ num_col = (w - 1) // k2 + 1
+ self.nr = num_row
+ self.nc = num_col
+
+ import math
+
+ step_j = k2 if num_col == 1 else math.ceil(k2 * self.overlap)
+ step_i = k1 if num_row == 1 else math.ceil(k1 * self.overlap)
+ parts = []
+ idxes = []
+ i = 0
+ last_i = False
+ while i < h and not last_i:
+ j = 0
+ if i + k1 >= h:
+ i = h - k1
+ last_i = True
+ last_j = False
+ while j < w and not last_j:
+ if j + k2 >= w:
+ j = w - k2
+ last_j = True
+ parts.append(x[:, :, i : i + k1, j : j + k2])
+ idxes.append({"i": i, "j": j})
+ j = j + step_j
+ i = i + step_i
+ return parts
+
+ def grids(self, x):
+ """Split the input tensor into overlapping grid patches and concatenate them.
+
+ Args:
+ x (torch.Tensor): Input tensor of shape (batch, channels, height, width).
+
+ Returns:
+ torch.Tensor: Concatenated tensor of all grid patches.
+ """
+ b, c, h, w = x.shape
+ self.original_size = (b, c, h, w)
+ assert b == 1
+ k1, k2 = self.kernel_size
+ if h < k1:
+ k1 = h
+ if w < k2:
+ k2 = w
+ self.tile_weights = self._gaussian_weights(k2, k1)
+ num_row = (h - 1) // k1 + 1
+ num_col = (w - 1) // k2 + 1
+ self.nr = num_row
+ self.nc = num_col
+
+ import math
+
+ step_j = k2 if num_col == 1 else math.ceil(k2 * self.overlap)
+ step_i = k1 if num_row == 1 else math.ceil(k1 * self.overlap)
+ parts = []
+ idxes = []
+ i = 0
+ last_i = False
+ while i < h and not last_i:
+ j = 0
+ if i + k1 >= h:
+ i = h - k1
+ last_i = True
+ last_j = False
+ while j < w and not last_j:
+ if j + k2 >= w:
+ j = w - k2
+ last_j = True
+ parts.append(x[:, :, i : i + k1, j : j + k2])
+ idxes.append({"i": i, "j": j})
+ j = j + step_j
+ i = i + step_i
+ self.idxes = idxes
+ return torch.cat(parts, dim=0)
+
+ def _gaussian_weights(self, tile_width, tile_height):
+ """Generate a Gaussian weight mask for tile contributions.
+
+ Args:
+ tile_width (int): Width of the tile.
+ tile_height (int): Height of the tile.
+
+ Returns:
+ torch.Tensor: Gaussian weight tensor of shape (channels, height, width).
+ """
+ import numpy as np
+ from numpy import exp, pi, sqrt
+
+ latent_width = tile_width
+ latent_height = tile_height
+ var = 0.01
+ midpoint = (latent_width - 1) / 2
+ x_probs = [
+ exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)
+ for x in range(latent_width)
+ ]
+ midpoint = latent_height / 2
+ y_probs = [
+ exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)
+ for y in range(latent_height)
+ ]
+ weights = np.outer(y_probs, x_probs)
+ return torch.tile(torch.tensor(weights, device=torch.device("cuda")), (4, 1, 1))
+
+ def grids_inverse(self, outs):
+ """Reconstruct the original tensor from processed grid patches with overlap blending.
+
+ Args:
+ outs (torch.Tensor): Processed grid patches.
+
+ Returns:
+ torch.Tensor: Reconstructed tensor of original size.
+ """
+ preds = torch.zeros(self.original_size).to(outs.device)
+ b, c, h, w = self.original_size
+ count_mt = torch.zeros((b, 4, h, w)).to(outs.device)
+ k1, k2 = self.kernel_size
+
+ for cnt, each_idx in enumerate(self.idxes):
+ i = each_idx["i"]
+ j = each_idx["j"]
+ preds[0, :, i : i + k1, j : j + k2] += outs[cnt, :, :, :] * self.tile_weights
+ count_mt[0, :, i : i + k1, j : j + k2] += self.tile_weights
+
+ del outs
+ torch.cuda.empty_cache()
+ return preds / count_mt
+
+ def _pad(self, x):
+ """Pad the input tensor to align with kernel size.
+
+ Args:
+ x (torch.Tensor): Input tensor of shape (batch, channels, height, width).
+
+ Returns:
+ tuple: Padded tensor and padding values.
+ """
+ b, c, h, w = x.shape
+ k1, k2 = self.kernel_size
+ mod_pad_h = (k1 - h % k1) % k1
+ mod_pad_w = (k2 - w % k2) % k2
+ pad = (mod_pad_w // 2, mod_pad_w - mod_pad_w // 2, mod_pad_h // 2, mod_pad_h - mod_pad_h // 2)
+ x = F.pad(x, pad, "reflect")
+ return x, pad
+
+ def forward(self, x):
+ """Apply local attention by splitting into grids and reconstructing.
+
+ Args:
+ x (torch.Tensor): Input tensor of shape (batch, channels, height, width).
+
+ Returns:
+ torch.Tensor: Processed tensor of original size.
+ """
+ b, c, h, w = x.shape
+ qkv = self.grids(x)
+ out = self.grids_inverse(qkv)
+ return out
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+
+ Args:
+ noise_cfg (torch.Tensor): Noise configuration tensor.
+ noise_pred_text (torch.Tensor): Predicted noise from text-conditioned model.
+ guidance_rescale (float): Rescaling factor for guidance. Defaults to 0.0.
+
+ Returns:
+ torch.Tensor: Rescaled noise configuration.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ """Retrieve latents from an encoder output.
+
+ Args:
+ encoder_output (torch.Tensor): Output from an encoder (e.g., VAE).
+ generator (torch.Generator, optional): Random generator for sampling. Defaults to None.
+ sample_mode (str): Sampling mode ("sample" or "argmax"). Defaults to "sample".
+
+ Returns:
+ torch.Tensor: Retrieved latent tensor.
+ """
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FaithDiffStableDiffusionXLPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ FromSingleFileMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ unet_model = UNet2DConditionModel
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+ _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "feature_extractor", "unet"]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "negative_pooled_prompt_embeds",
+ "negative_add_time_ids",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: OriginalUNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.DDPMScheduler = DDPMScheduler.from_config(self.scheduler.config, subfolder="scheduler")
+ self.default_sample_size = self.unet.config.sample_size if unet is not None else 128
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = "cuda" # device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+ dtype = text_encoders[0].dtype
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_encoder = text_encoder.to(dtype)
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_image_size(self, x, padder_size=8):
+ # 获取图像的宽高
+ width, height = x.size
+ padder_size = padder_size
+ # 计算需要填充的高度和宽度
+ mod_pad_h = (padder_size - height % padder_size) % padder_size
+ mod_pad_w = (padder_size - width % padder_size) % padder_size
+ x_np = np.array(x)
+ # 使用 ImageOps.expand 进行填充
+ x_padded = cv2.copyMakeBorder(
+ x_np, top=0, bottom=mod_pad_h, left=0, right=mod_pad_w, borderType=cv2.BORDER_REPLICATE
+ )
+
+ x = PIL.Image.fromarray(x_padded)
+ # x = x.resize((width + mod_pad_w, height + mod_pad_h))
+
+ return x, width, height, width + mod_pad_w, height + mod_pad_h
+
+ def check_inputs(
+ self,
+ lr_img,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if lr_img is None:
+ raise ValueError("`lr_image` must be provided!")
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.FloatTensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ def set_encoder_tile_settings(
+ self,
+ denoise_encoder_tile_sample_min_size=1024,
+ denoise_encoder_sample_overlap_factor=0.25,
+ vae_sample_size=1024,
+ vae_tile_overlap_factor=0.25,
+ ):
+ self.unet.denoise_encoder.tile_sample_min_size = denoise_encoder_tile_sample_min_size
+ self.unet.denoise_encoder.tile_overlap_factor = denoise_encoder_sample_overlap_factor
+ self.vae.config.sample_size = vae_sample_size
+ self.vae.tile_overlap_factor = vae_tile_overlap_factor
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+ self.unet.denoise_encoder.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+ self.unet.denoise_encoder.disable_tiling()
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ def prepare_image_latents(
+ self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
+ ):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ image_latents = image
+ else:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ # needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ # if needs_upcasting:
+ # image = image.float()
+ # self.upcast_vae()
+ self.unet.denoise_encoder.to(device=image.device, dtype=image.dtype)
+ image_latents = self.unet.denoise_encoder(image)
+ self.unet.denoise_encoder.to("cpu")
+ # cast back to fp16 if needed
+ # if needs_upcasting:
+ # self.vae.to(dtype=torch.float16)
+
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand image_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ if do_classifier_free_guidance:
+ image_latents = image_latents
+
+ if image_latents.dtype != self.vae.dtype:
+ image_latents = image_latents.to(dtype=self.vae.dtype)
+
+ return image_latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ lr_img: PipelineImageInput = None,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ start_point: Optional[str] = "noise",
+ timesteps: List[int] = None,
+ denoising_end: Optional[float] = None,
+ overlap: float = 0.5,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ add_sample: bool = True,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ lr_img (PipelineImageInput, optional): Low-resolution input image for conditioning the generation process.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ start_point (str, *optional*):
+ The starting point for the generation process. Can be "noise" (random noise) or "lr" (low-resolution image).
+ Defaults to "noise".
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ overlap (float):
+ Overlap factor for local attention tiling (between 0.0 and 1.0). Controls the overlap between adjacent
+ grid patches during processing. Defaults to 0.5.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ add_sample (bool):
+ Whether to include sample conditioning (e.g., low-resolution image) in the UNet during denoising.
+ Defaults to True.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ lr_img,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._interrupt = False
+ self.tlc_vae_latents = LocalAttention((target_size[0] // 8, target_size[1] // 8), overlap)
+ self.tlc_vae_img = LocalAttention((target_size[0] // 8, target_size[1] // 8), overlap)
+
+ # 2. Define call parameters
+ batch_size = 1
+ num_images_per_prompt = 1
+
+ device = torch.device("cuda") # self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ num_samples = num_images_per_prompt
+ with torch.inference_mode():
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt,
+ num_images_per_prompt=num_samples,
+ do_classifier_free_guidance=True,
+ negative_prompt=negative_prompt,
+ lora_scale=lora_scale,
+ )
+
+ lr_img_list = [lr_img]
+ lr_img = self.image_processor.preprocess(lr_img_list, height=height, width=width).to(
+ device, dtype=prompt_embeds.dtype
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ image_latents = self.prepare_image_latents(
+ lr_img, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, self.do_classifier_free_guidance
+ )
+
+ image_latents = self.tlc_vae_img.grids(image_latents)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ if start_point == "lr":
+ latents_condition_image = self.vae.encode(lr_img * 2 - 1).latent_dist.sample()
+ latents_condition_image = latents_condition_image * self.vae.config.scaling_factor
+ start_steps_tensor = torch.randint(999, 999 + 1, (latents.shape[0],), device=latents.device)
+ start_steps_tensor = start_steps_tensor.long()
+ latents = self.DDPMScheduler.add_noise(latents_condition_image[0:1, ...], latents, start_steps_tensor)
+
+ latents = self.tlc_vae_latents.grids(latents)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * image_latents.shape[0]
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 9. Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+ sub_latents_num = latents.shape[0]
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if i >= 1:
+ latents = self.tlc_vae_latents.grids(latents).to(dtype=latents.dtype)
+ if self.interrupt:
+ continue
+ concat_grid = []
+ for sub_num in range(sub_latents_num):
+ self.scheduler.__dict__.update(views_scheduler_status[sub_num])
+ sub_latents = latents[sub_num, :, :, :].unsqueeze(0)
+ img_sub_latents = image_latents[sub_num, :, :, :].unsqueeze(0)
+ latent_model_input = (
+ torch.cat([sub_latents] * 2) if self.do_classifier_free_guidance else sub_latents
+ )
+ img_sub_latents = (
+ torch.cat([img_sub_latents] * 2) if self.do_classifier_free_guidance else img_sub_latents
+ )
+ scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ pos_height = self.tlc_vae_latents.idxes[sub_num]["i"]
+ pos_width = self.tlc_vae_latents.idxes[sub_num]["j"]
+ add_time_ids = [
+ torch.tensor([original_size]),
+ torch.tensor([[pos_height, pos_width]]),
+ torch.tensor([target_size]),
+ ]
+ add_time_ids = torch.cat(add_time_ids, dim=1).to(
+ img_sub_latents.device, dtype=img_sub_latents.dtype
+ )
+ add_time_ids = add_time_ids.repeat(2, 1).to(dtype=img_sub_latents.dtype)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ with torch.amp.autocast(
+ device.type, dtype=latents.dtype, enabled=latents.dtype != self.unet.dtype
+ ):
+ noise_pred = self.unet(
+ scaled_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ input_embedding=img_sub_latents,
+ add_sample=add_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = sub_latents.dtype
+ sub_latents = self.scheduler.step(
+ noise_pred, t, sub_latents, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ views_scheduler_status[sub_num] = copy.deepcopy(self.scheduler.__dict__)
+ concat_grid.append(sub_latents)
+ if latents.dtype != sub_latents:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ sub_latents = sub_latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ latents = self.tlc_vae_latents.grids_inverse(torch.cat(concat_grid, dim=0)).to(sub_latents.dtype)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ elif latents.dtype != self.vae.dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ self.vae = self.vae.to(latents.dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py
index 68cb69115bde..9d6be763a0a0 100644
--- a/examples/community/pipeline_flux_differential_img2img.py
+++ b/examples/community/pipeline_flux_differential_img2img.py
@@ -87,7 +87,7 @@ def calculate_shift(
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
- max_shift: float = 1.16,
+ max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
@@ -221,13 +221,12 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
- vae_latent_channels=self.vae.config.latent_channels,
+ vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=False,
do_convert_grayscale=True,
@@ -876,10 +875,10 @@ def __call__(
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
mu = calculate_shift(
image_seq_len,
- self.scheduler.config.base_image_seq_len,
- self.scheduler.config.max_image_seq_len,
- self.scheduler.config.base_shift,
- self.scheduler.config.max_shift,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py
new file mode 100644
index 000000000000..572856a047b2
--- /dev/null
+++ b/examples/community/pipeline_flux_rf_inversion.py
@@ -0,0 +1,1058 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# modeled after RF Inversion: https://rf-inversion.github.io/, authored by Litu Rout, Yujia Chen, Nataniel Ruiz,
+# Constantine Caramanis, Sanjay Shakkottai and Wen-Sheng Chu.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.transformers import FluxTransformer2DModel
+from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> import requests
+ >>> import PIL
+ >>> from io import BytesIO
+ >>> from diffusers import DiffusionPipeline
+
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-dev",
+ ... torch_dtype=torch.bfloat16,
+ ... custom_pipeline="pipeline_flux_rf_inversion")
+ >>> pipe.to("cuda")
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
+ >>> image = download_image(img_url)
+
+ >>> inverted_latents, image_latents, latent_image_ids = pipe.invert(image=image, num_inversion_steps=28, gamma=0.5)
+
+ >>> edited_image = pipe(
+ ... prompt="a tomato",
+ ... inverted_latents=inverted_latents,
+ ... image_latents=image_latents,
+ ... latent_image_ids=latent_image_ids,
+ ... start_timestep=0,
+ ... stop_timestep=.25,
+ ... num_inference_steps=28,
+ ... eta=0.9,
+ ... ).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class RFInversionFluxPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ The Flux pipeline for text-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ @torch.no_grad()
+ # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image
+ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None):
+ image = self.image_processor.preprocess(
+ image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+ resized = self.image_processor.postprocess(image=image, output_type="pil")
+
+ if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
+ logger.warning(
+ "Your input images far exceed the default resolution of the underlying diffusion model. "
+ "The output images may contain severe artifacts! "
+ "Consider down-sampling the input using the `height` and `width` parameters"
+ )
+ image = image.to(dtype)
+
+ x0 = self.vae.encode(image.to(self._execution_device)).latent_dist.sample()
+ x0 = (x0 - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ x0 = x0.to(dtype)
+ return x0, resized
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ inverted_latents,
+ image_latents,
+ latent_image_ids,
+ height,
+ width,
+ start_timestep,
+ stop_timestep,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ if inverted_latents is not None and (image_latents is None or latent_image_ids is None):
+ raise ValueError(
+ "If `inverted_latents` are provided, `image_latents` and `latent_image_ids` also have to be passed. "
+ )
+ # check start_timestep and stop_timestep
+ if start_timestep < 0 or start_timestep > stop_timestep:
+ raise ValueError(f"`start_timestep` should be in [0, stop_timestep] but is {start_timestep}")
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents_inversion(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ image_latents,
+ ):
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
+
+ latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength=1.0):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ sigmas = self.scheduler.sigmas[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, sigmas, num_inference_steps - t_start
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ inverted_latents: Optional[torch.FloatTensor] = None,
+ image_latents: Optional[torch.FloatTensor] = None,
+ latent_image_ids: Optional[torch.FloatTensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ eta: float = 1.0,
+ decay_eta: Optional[bool] = False,
+ eta_decay_power: Optional[float] = 1.0,
+ strength: float = 1.0,
+ start_timestep: float = 0,
+ stop_timestep: float = 0.25,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead
+ inverted_latents (`torch.Tensor`, *optional*):
+ The inverted latents from `pipe.invert`.
+ image_latents (`torch.Tensor`, *optional*):
+ The image latents from `pipe.invert`.
+ latent_image_ids (`torch.Tensor`, *optional*):
+ The latent image ids from `pipe.invert`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ eta (`float`, *optional*, defaults to 1.0):
+ The controller guidance, balancing faithfulness & editability:
+ higher eta - better faithfullness, less editability. For more significant edits, lower the value of eta.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ inverted_latents,
+ image_latents,
+ latent_image_ids,
+ height,
+ width,
+ start_timestep,
+ stop_timestep,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+ do_rf_inversion = inverted_latents is not None
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ if do_rf_inversion:
+ latents = inverted_latents
+ else:
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ if do_rf_inversion:
+ start_timestep = int(start_timestep * num_inference_steps)
+ stop_timestep = min(int(stop_timestep * num_inference_steps), num_inference_steps)
+ timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if do_rf_inversion:
+ y_0 = image_latents.clone()
+ # 6. Denoising loop / Controlled Reverse ODE, Algorithm 2 from: https://arxiv.org/pdf/2410.10792
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if do_rf_inversion:
+ # ti (current timestep) as annotated in algorithm 2 - i/num_inference_steps.
+ t_i = 1 - t / 1000
+ dt = torch.tensor(1 / (len(timesteps) - 1), device=device)
+
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ latents_dtype = latents.dtype
+ if do_rf_inversion:
+ v_t = -noise_pred
+ v_t_cond = (y_0 - latents) / (1 - t_i)
+ eta_t = eta if start_timestep <= i < stop_timestep else 0.0
+ if decay_eta:
+ eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power # Decay eta over the loop
+ v_hat_t = v_t + eta_t * (v_t_cond - v_t)
+
+ # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792
+ latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])
+ else:
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
+
+ @torch.no_grad()
+ def invert(
+ self,
+ image: PipelineImageInput,
+ source_prompt: str = "",
+ source_guidance_scale=0.0,
+ num_inversion_steps: int = 28,
+ strength: float = 1.0,
+ gamma: float = 0.5,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ timesteps: List[int] = None,
+ dtype: Optional[torch.dtype] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Performs Algorithm 1: Controlled Forward ODE from https://arxiv.org/pdf/2410.10792
+ Args:
+ image (`PipelineImageInput`):
+ Input for the image(s) that are to be edited. Multiple input images have to default to the same aspect
+ ratio.
+ source_prompt (`str` or `List[str]`, *optional* defaults to an empty prompt as done in the original paper):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ source_guidance_scale (`float`, *optional*, defaults to 0.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). For this algorithm, it's better to keep it 0.
+ num_inversion_steps (`int`, *optional*, defaults to 28):
+ The number of discretization steps.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ gamma (`float`, *optional*, defaults to 0.5):
+ The controller guidance for the forward ODE, balancing faithfulness & editability:
+ higher eta - better faithfullness, less editability. For more significant edits, lower the value of eta.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ """
+ dtype = dtype or self.text_encoder.dtype
+ batch_size = 1
+ self._joint_attention_kwargs = joint_attention_kwargs
+ num_channels_latents = self.transformer.config.in_channels // 4
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ device = self._execution_device
+
+ # 1. prepare image
+ image_latents, _ = self.encode_image(image, height=height, width=width, dtype=dtype)
+ image_latents, latent_image_ids = self.prepare_latents_inversion(
+ batch_size, num_channels_latents, height, width, dtype, device, image_latents
+ )
+
+ # 2. prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inversion_steps, num_inversion_steps)
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inversion_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inversion_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ timesteps, sigmas, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength)
+
+ # 3. prepare text embeddings
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=source_prompt,
+ prompt_2=source_prompt,
+ device=device,
+ )
+ # 4. handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], source_guidance_scale, device=device, dtype=torch.float32)
+ else:
+ guidance = None
+
+ # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt
+ Y_t = image_latents
+ y_1 = torch.randn_like(Y_t)
+ N = len(sigmas)
+
+ # forward ODE loop
+ with self.progress_bar(total=N - 1) as progress_bar:
+ for i in range(N - 1):
+ t_i = torch.tensor(i / (N), dtype=Y_t.dtype, device=device)
+ timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size)
+
+ # get the unconditional vector field
+ u_t_i = self.transformer(
+ hidden_states=Y_t,
+ timestep=timestep,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # get the conditional vector field
+ u_t_i_cond = (y_1 - Y_t) / (1 - t_i)
+
+ # controlled vector field
+ # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt
+ u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i)
+ Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1])
+ progress_bar.update()
+
+ # return the inverted latents (start point for the denoising loop), encoded image & latent image ids
+ return Y_t, image_latents, latent_image_ids
diff --git a/examples/community/pipeline_flux_semantic_guidance.py b/examples/community/pipeline_flux_semantic_guidance.py
new file mode 100644
index 000000000000..919e0ad46bd1
--- /dev/null
+++ b/examples/community/pipeline_flux_semantic_guidance.py
@@ -0,0 +1,1351 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.transformers import FluxTransformer2DModel
+from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ >>> "black-forest-labs/FLUX.1-dev",
+ >>> custom_pipeline="pipeline_flux_semantic_guidance",
+ >>> torch_dtype=torch.bfloat16
+ >>> )
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> image = pipe(
+ >>> prompt=prompt,
+ >>> num_inference_steps=28,
+ >>> guidance_scale=3.5,
+ >>> editing_prompt=["cat", "dog"], # changes from cat to dog.
+ >>> reverse_editing_direction=[True, False],
+ >>> edit_warmup_steps=[6, 8],
+ >>> edit_guidance_scale=[6, 6.5],
+ >>> edit_threshold=[0.89, 0.89],
+ >>> edit_cooldown_steps = [25, 27],
+ >>> edit_momentum_scale=0.3,
+ >>> edit_mom_beta=0.6,
+ >>> generator=torch.Generator(device="cuda").manual_seed(6543),
+ >>> ).images[0]
+ >>> image.save("semantic_flux.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxSemanticGuidancePipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
+):
+ r"""
+ The Flux pipeline for text-to-image generation with semantic guidance.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ def encode_text_with_editing(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ editing_prompt: Optional[List[str]] = None,
+ editing_prompt_2: Optional[List[str]] = None,
+ editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ """
+ Encode text prompts with editing prompts and negative prompts for semantic guidance.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide image generation.
+ prompt_2 (`str` or `List[str]`):
+ The prompt or prompts to guide image generation for second tokenizer.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ editing_prompt (`str` or `List[str]`, *optional*):
+ The editing prompts for semantic guidance.
+ editing_prompt_2 (`str` or `List[str]`, *optional*):
+ The editing prompts for semantic guidance for second tokenizer.
+ editing_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-computed embeddings for editing prompts.
+ pooled_editing_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-computed pooled embeddings for editing prompts.
+ device (`torch.device`, *optional*):
+ The device to use for computation.
+ num_images_per_prompt (`int`, defaults to 1):
+ Number of images to generate per prompt.
+ max_sequence_length (`int`, defaults to 512):
+ Maximum sequence length for text encoding.
+ lora_scale (`float`, *optional*):
+ Scale factor for LoRA layers if used.
+
+ Returns:
+ tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, int]:
+ A tuple containing the prompt embeddings, pooled prompt embeddings,
+ text IDs, and number of enabled editing prompts.
+ """
+ device = device or self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError("Prompt must be provided as string or list of strings")
+
+ # Get base prompt embeddings
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # Handle editing prompts
+ if editing_prompt_embeds is not None:
+ enabled_editing_prompts = int(editing_prompt_embeds.shape[0])
+ edit_text_ids = []
+ elif editing_prompt is not None:
+ editing_prompt_embeds = []
+ pooled_editing_prompt_embeds = []
+ edit_text_ids = []
+
+ editing_prompt_2 = editing_prompt if editing_prompt_2 is None else editing_prompt_2
+ for edit_1, edit_2 in zip(editing_prompt, editing_prompt_2):
+ e_prompt_embeds, pooled_embeds, e_ids = self.encode_prompt(
+ prompt=edit_1,
+ prompt_2=edit_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ editing_prompt_embeds.append(e_prompt_embeds)
+ pooled_editing_prompt_embeds.append(pooled_embeds)
+ edit_text_ids.append(e_ids)
+
+ enabled_editing_prompts = len(editing_prompt)
+
+ else:
+ edit_text_ids = []
+ enabled_editing_prompts = 0
+
+ if enabled_editing_prompts:
+ for idx in range(enabled_editing_prompts):
+ editing_prompt_embeds[idx] = torch.cat([editing_prompt_embeds[idx]] * batch_size, dim=0)
+ pooled_editing_prompt_embeds[idx] = torch.cat([pooled_editing_prompt_embeds[idx]] * batch_size, dim=0)
+
+ return (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ editing_prompt_embeds,
+ pooled_editing_prompt_embeds,
+ text_ids,
+ edit_text_ids,
+ enabled_editing_prompts,
+ )
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
+ ):
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ editing_prompt: Optional[Union[str, List[str]]] = None,
+ editing_prompt_2: Optional[Union[str, List[str]]] = None,
+ editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,
+ reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
+ edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
+ edit_warmup_steps: Optional[Union[int, List[int]]] = 8,
+ edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
+ edit_threshold: Optional[Union[float, List[float]]] = 0.9,
+ edit_momentum_scale: Optional[float] = 0.1,
+ edit_mom_beta: Optional[float] = 0.4,
+ edit_weights: Optional[List[float]] = None,
+ sem_guidance: Optional[List[torch.Tensor]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+ editing_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image editing. If not defined, no editing will be performed.
+ editing_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image editing. If not defined, will use editing_prompt instead.
+ editing_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings for editing. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, text embeddings will be generated from `editing_prompt` input argument.
+ reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`):
+ Whether to reverse the editing direction for each editing prompt.
+ edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
+ Guidance scale for the editing process. If provided as a list, each value corresponds to an editing prompt.
+ edit_warmup_steps (`int` or `List[int]`, *optional*, defaults to 10):
+ Number of warmup steps for editing guidance. If provided as a list, each value corresponds to an editing prompt.
+ edit_cooldown_steps (`int` or `List[int]`, *optional*, defaults to None):
+ Number of cooldown steps for editing guidance. If provided as a list, each value corresponds to an editing prompt.
+ edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
+ Threshold for editing guidance. If provided as a list, each value corresponds to an editing prompt.
+ edit_momentum_scale (`float`, *optional*, defaults to 0.1):
+ Scale of momentum to be added to the editing guidance at each diffusion step.
+ edit_mom_beta (`float`, *optional*, defaults to 0.4):
+ Beta value for momentum calculation in editing guidance.
+ edit_weights (`List[float]`, *optional*):
+ Weights for each editing prompt.
+ sem_guidance (`List[torch.Tensor]`, *optional*):
+ Pre-generated semantic guidance. If provided, it will be used instead of calculating guidance from editing prompts.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if editing_prompt:
+ enable_edit_guidance = True
+ if isinstance(editing_prompt, str):
+ editing_prompt = [editing_prompt]
+ enabled_editing_prompts = len(editing_prompt)
+ elif editing_prompt_embeds is not None:
+ enable_edit_guidance = True
+ enabled_editing_prompts = editing_prompt_embeds.shape[0]
+ else:
+ enabled_editing_prompts = 0
+ enable_edit_guidance = False
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ editing_prompts_embeds,
+ pooled_editing_prompt_embeds,
+ text_ids,
+ edit_text_ids,
+ enabled_editing_prompts,
+ ) = self.encode_text_with_editing(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ editing_prompt=editing_prompt,
+ editing_prompt_2=editing_prompt_2,
+ pooled_editing_prompt_embeds=pooled_editing_prompt_embeds,
+ lora_scale=lora_scale,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ _,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds] * batch_size, dim=0)
+ negative_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds] * batch_size, dim=0)
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ edit_momentum = None
+ if edit_warmup_steps:
+ tmp_e_warmup_steps = edit_warmup_steps if isinstance(edit_warmup_steps, list) else [edit_warmup_steps]
+ min_edit_warmup_steps = min(tmp_e_warmup_steps)
+ else:
+ min_edit_warmup_steps = 0
+
+ if edit_cooldown_steps:
+ tmp_e_cooldown_steps = (
+ edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps]
+ )
+ max_edit_cooldown_steps = min(max(tmp_e_cooldown_steps), num_inference_steps)
+ else:
+ max_edit_cooldown_steps = num_inference_steps
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.tensor([guidance_scale], device=device)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps:
+ noise_pred_edit_concepts = []
+ for e_embed, pooled_e_embed, e_text_id in zip(
+ editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids
+ ):
+ noise_pred_edit = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_e_embed,
+ encoder_hidden_states=e_embed,
+ txt_ids=e_text_id,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred_edit_concepts.append(noise_pred_edit)
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ noise_pred_uncond = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_guidance = true_cfg_scale * (noise_pred - noise_pred_uncond)
+ else:
+ noise_pred_uncond = noise_pred
+ noise_guidance = noise_pred
+
+ if edit_momentum is None:
+ edit_momentum = torch.zeros_like(noise_guidance)
+
+ if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps:
+ concept_weights = torch.zeros(
+ (enabled_editing_prompts, noise_guidance.shape[0]),
+ device=device,
+ dtype=noise_guidance.dtype,
+ )
+ noise_guidance_edit = torch.zeros(
+ (enabled_editing_prompts, *noise_guidance.shape),
+ device=device,
+ dtype=noise_guidance.dtype,
+ )
+
+ warmup_inds = []
+ for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
+ if isinstance(edit_guidance_scale, list):
+ edit_guidance_scale_c = edit_guidance_scale[c]
+ else:
+ edit_guidance_scale_c = edit_guidance_scale
+
+ if isinstance(edit_threshold, list):
+ edit_threshold_c = edit_threshold[c]
+ else:
+ edit_threshold_c = edit_threshold
+ if isinstance(reverse_editing_direction, list):
+ reverse_editing_direction_c = reverse_editing_direction[c]
+ else:
+ reverse_editing_direction_c = reverse_editing_direction
+ if edit_weights:
+ edit_weight_c = edit_weights[c]
+ else:
+ edit_weight_c = 1.0
+ if isinstance(edit_warmup_steps, list):
+ edit_warmup_steps_c = edit_warmup_steps[c]
+ else:
+ edit_warmup_steps_c = edit_warmup_steps
+
+ if isinstance(edit_cooldown_steps, list):
+ edit_cooldown_steps_c = edit_cooldown_steps[c]
+ elif edit_cooldown_steps is None:
+ edit_cooldown_steps_c = i + 1
+ else:
+ edit_cooldown_steps_c = edit_cooldown_steps
+ if i >= edit_warmup_steps_c:
+ warmup_inds.append(c)
+ if i >= edit_cooldown_steps_c:
+ noise_guidance_edit[c, :, :, :] = torch.zeros_like(noise_pred_edit_concept)
+ continue
+
+ if do_true_cfg:
+ noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
+ else: # simple sega
+ noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred
+ tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2))
+
+ tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts)
+ if reverse_editing_direction_c:
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
+ concept_weights[c, :] = tmp_weights
+
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c
+
+ # torch.quantile function expects float32
+ if noise_guidance_edit_tmp.dtype == torch.float32:
+ tmp = torch.quantile(
+ torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2),
+ edit_threshold_c,
+ dim=2,
+ keepdim=False,
+ )
+ else:
+ tmp = torch.quantile(
+ torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32),
+ edit_threshold_c,
+ dim=2,
+ keepdim=False,
+ ).to(noise_guidance_edit_tmp.dtype)
+
+ noise_guidance_edit_tmp = torch.where(
+ torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None],
+ noise_guidance_edit_tmp,
+ torch.zeros_like(noise_guidance_edit_tmp),
+ )
+
+ noise_guidance_edit[c, :, :, :] = noise_guidance_edit_tmp
+
+ warmup_inds = torch.tensor(warmup_inds).to(device)
+ if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
+ concept_weights = concept_weights.to("cpu") # Offload to cpu
+ noise_guidance_edit = noise_guidance_edit.to("cpu")
+
+ concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds)
+ concept_weights_tmp = torch.where(
+ concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
+ )
+ concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
+
+ noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
+ noise_guidance_edit_tmp = torch.einsum(
+ "cb,cbij->bij", concept_weights_tmp, noise_guidance_edit_tmp
+ )
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp
+ noise_guidance = noise_guidance + noise_guidance_edit_tmp
+
+ del noise_guidance_edit_tmp
+ del concept_weights_tmp
+ concept_weights = concept_weights.to(device)
+ noise_guidance_edit = noise_guidance_edit.to(device)
+
+ concept_weights = torch.where(
+ concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
+ )
+
+ concept_weights = torch.nan_to_num(concept_weights)
+
+ noise_guidance_edit = torch.einsum("cb,cbij->bij", concept_weights, noise_guidance_edit)
+
+ noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
+
+ edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit
+
+ if warmup_inds.shape[0] == len(noise_pred_edit_concepts):
+ noise_guidance = noise_guidance + noise_guidance_edit
+
+ if sem_guidance is not None:
+ edit_guidance = sem_guidance[i].to(device)
+ noise_guidance = noise_guidance + edit_guidance
+
+ if do_true_cfg:
+ noise_pred = noise_guidance + noise_pred_uncond
+ else:
+ noise_pred = noise_guidance
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(
+ image,
+ )
diff --git a/examples/community/pipeline_flux_with_cfg.py b/examples/community/pipeline_flux_with_cfg.py
index 06da6da899cd..f55f73620f45 100644
--- a/examples/community/pipeline_flux_with_cfg.py
+++ b/examples/community/pipeline_flux_with_cfg.py
@@ -64,12 +64,13 @@
"""
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
- max_shift: float = 1.16,
+ max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
@@ -189,9 +190,7 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -757,10 +756,10 @@ def __call__(
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
- self.scheduler.config.base_image_seq_len,
- self.scheduler.config.max_image_seq_len,
- self.scheduler.config.base_shift,
- self.scheduler.config.max_shift,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
diff --git a/examples/community/pipeline_hunyuandit_differential_img2img.py b/examples/community/pipeline_hunyuandit_differential_img2img.py
index 3ece670e5bde..a294ff782450 100644
--- a/examples/community/pipeline_hunyuandit_differential_img2img.py
+++ b/examples/community/pipeline_hunyuandit_differential_img2img.py
@@ -327,9 +327,7 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
@@ -1008,6 +1006,8 @@ def __call__(
self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
+ device=device,
+ output_type="pt",
)
style = torch.tensor([0], device=device)
diff --git a/examples/community/pipeline_kolors_differential_img2img.py b/examples/community/pipeline_kolors_differential_img2img.py
index e5570248d22b..dfef872d1c30 100644
--- a/examples/community/pipeline_kolors_differential_img2img.py
+++ b/examples/community/pipeline_kolors_differential_img2img.py
@@ -209,16 +209,18 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True
)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
# Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt
def encode_prompt(
diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py
index 508e84177928..736f00799eae 100644
--- a/examples/community/pipeline_prompt2prompt.py
+++ b/examples/community/pipeline_prompt2prompt.py
@@ -131,7 +131,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -145,7 +145,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -174,10 +174,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -205,7 +209,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/pipeline_sdxl_style_aligned.py b/examples/community/pipeline_sdxl_style_aligned.py
index 8328bc2caed9..9377caf7ba2e 100644
--- a/examples/community/pipeline_sdxl_style_aligned.py
+++ b/examples/community/pipeline_sdxl_style_aligned.py
@@ -488,13 +488,17 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -628,7 +632,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -688,7 +694,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
index 8cee5ecbc141..50952304fc1e 100644
--- a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
+++ b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
@@ -207,7 +207,7 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
)
diff --git a/examples/community/pipeline_stable_diffusion_boxdiff.py b/examples/community/pipeline_stable_diffusion_boxdiff.py
index 6490c1400138..bd58a65ce787 100644
--- a/examples/community/pipeline_stable_diffusion_boxdiff.py
+++ b/examples/community/pipeline_stable_diffusion_boxdiff.py
@@ -417,7 +417,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -431,7 +431,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -460,10 +460,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -491,7 +495,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/pipeline_stable_diffusion_pag.py b/examples/community/pipeline_stable_diffusion_pag.py
index cea2c9735747..874303e0ad6c 100644
--- a/examples/community/pipeline_stable_diffusion_pag.py
+++ b/examples/community/pipeline_stable_diffusion_pag.py
@@ -384,7 +384,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -398,7 +398,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -427,10 +427,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -458,7 +462,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
index 1ac651a1fe60..8a709ab46757 100644
--- a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
+++ b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
@@ -151,7 +151,7 @@ def __init__(
watermarker=watermarker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor, resample="bilinear")
# self.register_to_config(requires_safety_checker=requires_safety_checker)
self.register_to_config(max_noise_level=max_noise_level)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
new file mode 100644
index 000000000000..1269a69f0dc3
--- /dev/null
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -0,0 +1,2318 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from PIL import Image
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import DDIMScheduler, DiffusionPipeline
+ >>> from diffusers.utils import load_image
+ >>> import torch.nn.functional as F
+ >>> from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+ >>> dtype = torch.float16
+ >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ >>> scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+
+ >>> pipeline = DiffusionPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
+ ... custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+ ... scheduler=scheduler,
+ ... variant="fp16",
+ ... use_safetensors=True,
+ ... torch_dtype=dtype,
+ ... ).to(device)
+
+
+ >>> def preprocess_image(image_path, device):
+ ... image = to_tensor((load_image(image_path)))
+ ... image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
+ ... if image.shape[1] != 3:
+ ... image = image.expand(-1, 3, -1, -1)
+ ... image = F.interpolate(image, (1024, 1024))
+ ... image = image.to(dtype).to(device)
+ ... return image
+
+ >>> def preprocess_mask(mask_path, device):
+ ... mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
+ ... mask = mask.unsqueeze_(0).float() # 0 or 1
+ ... mask = F.interpolate(mask, (1024, 1024))
+ ... mask = gaussian_blur(mask, kernel_size=(77, 77))
+ ... mask[mask < 0.1] = 0
+ ... mask[mask >= 0.1] = 1
+ ... mask = mask.to(dtype).to(device)
+ ... return mask
+
+ >>> prompt = "" # Set prompt to null
+ >>> seed=123
+ >>> generator = torch.Generator(device=device).manual_seed(seed)
+ >>> source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
+ >>> mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
+ >>> source_image = preprocess_image(source_image_path, device)
+ >>> mask = preprocess_mask(mask_path, device)
+
+ >>> image = pipeline(
+ ... prompt=prompt,
+ ... image=source_image,
+ ... mask_image=mask,
+ ... height=1024,
+ ... width=1024,
+ ... AAS=True, # enable AAS
+ ... strength=0.8, # inpainting strength
+ ... rm_guidance_scale=9, # removal guidance scale
+ ... ss_steps = 9, # similarity suppression steps
+ ... ss_scale = 0.3, # similarity suppression scale
+ ... AAS_start_step=0, # AAS start step
+ ... AAS_start_layer=34, # AAS start layer
+ ... AAS_end_layer=70, # AAS end layer
+ ... num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
+ ... generator=generator,
+ ... guidance_scale=1,
+ ... ).images[0]
+ >>> image.save('./removed_img.png')
+ >>> print("Object removal completed")
+ ```
+"""
+
+
+class AttentionBase:
+ def __init__(self):
+ self.cur_step = 0
+ self.num_att_layers = -1
+ self.cur_att_layer = 0
+
+ def after_step(self):
+ pass
+
+ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+ self.cur_att_layer += 1
+ if self.cur_att_layer == self.num_att_layers:
+ self.cur_att_layer = 0
+ self.cur_step += 1
+ # after step
+ self.after_step()
+ return out
+
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=num_heads)
+ return out
+
+ def reset(self):
+ self.cur_step = 0
+ self.cur_att_layer = 0
+
+
+class AAS_XL(AttentionBase):
+ MODEL_TYPE = {"SD": 16, "SDXL": 70}
+
+ def __init__(
+ self,
+ start_step=4,
+ end_step=50,
+ start_layer=10,
+ end_layer=16,
+ layer_idx=None,
+ step_idx=None,
+ total_steps=50,
+ mask=None,
+ model_type="SD",
+ ss_steps=9,
+ ss_scale=1.0,
+ ):
+ """
+ Args:
+ start_step: the step to start AAS
+ start_layer: the layer to start AAS
+ layer_idx: list of the layers to apply AAS
+ step_idx: list the steps to apply AAS
+ total_steps: the total number of steps
+ mask: source mask with shape (h, w)
+ model_type: the model type, SD or SDXL
+ """
+ super().__init__()
+ self.total_steps = total_steps
+ self.total_layers = self.MODEL_TYPE.get(model_type, 16)
+ self.start_step = start_step
+ self.end_step = end_step
+ self.start_layer = start_layer
+ self.end_layer = end_layer
+ self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, end_layer))
+ self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))
+ self.mask = mask # mask with shape (1, 1 ,h, w)
+ self.ss_steps = ss_steps
+ self.ss_scale = ss_scale
+ self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
+ self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
+ self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
+ self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()
+
+ def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, is_mask_attn, mask, **kwargs):
+ B = q.shape[0] // num_heads
+ if is_mask_attn:
+ mask_flatten = mask.flatten(0)
+ if self.cur_step <= self.ss_steps:
+ # background
+ sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+ # object
+ sim_fg = self.ss_scale * sim
+ sim_fg += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+ sim = torch.cat([sim_fg, sim_bg], dim=0)
+ else:
+ sim += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+ attn = sim.softmax(-1)
+ if len(attn) == 2 * len(v):
+ v = torch.cat([v] * 2)
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
+ out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+ return out
+
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+ """
+ Attention forward function
+ """
+ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+ H = int(np.sqrt(q.shape[1]))
+ if H == 16:
+ mask = self.mask_16.to(sim.device)
+ elif H == 32:
+ mask = self.mask_32.to(sim.device)
+ elif H == 64:
+ mask = self.mask_64.to(sim.device)
+ else:
+ mask = self.mask_128.to(sim.device)
+
+ q_wo, q_w = q.chunk(2)
+ k_wo, k_w = k.chunk(2)
+ v_wo, v_w = v.chunk(2)
+ sim_wo, sim_w = sim.chunk(2)
+ attn_wo, attn_w = attn.chunk(2)
+
+ out_source = self.attn_batch(
+ q_wo,
+ k_wo,
+ v_wo,
+ sim_wo,
+ attn_wo,
+ is_cross,
+ place_in_unet,
+ num_heads,
+ is_mask_attn=False,
+ mask=None,
+ **kwargs,
+ )
+ out_target = self.attn_batch(
+ q_w, k_w, v_w, sim_w, attn_w, is_cross, place_in_unet, num_heads, is_mask_attn=True, mask=mask, **kwargs
+ )
+
+ if self.mask is not None:
+ if out_target.shape[0] == 2:
+ out_target_fg, out_target_bg = out_target.chunk(2, 0)
+ mask = mask.reshape(-1, 1) # (hw, 1)
+ out_target = out_target_fg * mask + out_target_bg * (1 - mask)
+ else:
+ out_target = out_target
+
+ out = torch.cat([out_source, out_target], dim=0)
+ return out
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+def mask_pil_to_torch(mask, height, width):
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask = torch.from_numpy(mask)
+ return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+ (ot the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ mask = mask_pil_to_torch(mask, height, width)
+
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ # if image.min() < -1 or image.max() > 1:
+ # raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t passed height an width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = mask_pil_to_torch(mask, height, width)
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ if image.shape[1] == 4:
+ # images are in latent space and thus can't
+ # be masked set masked_image to None
+ # we assume that the checkpoint is not an inpainting
+ # checkpoint. TOD(Yiyi) - need to clean this up later
+ masked_image = None
+ else:
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionXL_AE_Pipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ FromSingleFileMixin,
+ IPAdapterMixin,
+):
+ r"""
+ Pipeline for object removal using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+ Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "image_encoder",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "negative_pooled_prompt_embeds",
+ "add_neg_time_ids",
+ "mask",
+ "masked_image_latents",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ repeat_dims = [1]
+ image_embeds = []
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+ )
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ else:
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ height,
+ width,
+ strength,
+ callback_steps,
+ output_type,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ add_noise=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ elif return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ dtype = image.dtype
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ # mask = torch.nn.functional.interpolate(
+ # mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ # )
+ mask = torch.nn.functional.max_pool2d(mask, (8, 8)).round()
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+ if masked_image is not None and masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if masked_image is not None:
+ if masked_image_latents is None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = 0
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ if denoising_start is not None:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ timesteps = timesteps[-num_inference_steps:]
+ return timesteps, num_inference_steps
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+ )
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ @property
+ def do_self_attention_redirection_guidance(self): # SARG
+ return self._rm_guidance_scale > 1 and self._AAS
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return (
+ self._guidance_scale > 1
+ and self.unet.config.time_cond_proj_dim is None
+ and not self.do_self_attention_redirection_guidance
+ ) # CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def denoising_start(self):
+ return self._denoising_start
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ def image2latent(self, image: torch.Tensor, generator: torch.Generator):
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ if type(image) is Image:
+ image = np.array(image)
+ image = torch.from_numpy(image).float() / 127.5 - 1
+ image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
+ # input image density range [-1, 1]
+ # latents = self.vae.encode(image)['latent_dist'].mean
+ latents = self._encode_vae_image(image, generator)
+ # latents = retrieve_latents(self.vae.encode(image))
+ # latents = latents * self.vae.config.scaling_factor
+ return latents
+
+ def next_step(self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta=0.0, verbose=False):
+ """
+ Inverse sampling for DDIM Inversion
+ """
+ if verbose:
+ print("timestep: ", timestep)
+ next_step = timestep
+ timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
+ beta_prod_t = 1 - alpha_prod_t
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+ pred_dir = (1 - alpha_prod_t_next) ** 0.5 * model_output
+ x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
+ return x_next, pred_x0
+
+ @torch.no_grad()
+ def invert(
+ self,
+ image: torch.Tensor,
+ prompt,
+ num_inference_steps=50,
+ eta=0.0,
+ original_size: Tuple[int, int] = None,
+ target_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ return_intermediates=False,
+ **kwds,
+ ):
+ """
+ invert a real image into noise map with determinisc DDIM inversion
+ """
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ batch_size = image.shape[0]
+ if isinstance(prompt, list):
+ if batch_size == 1:
+ image = image.expand(len(prompt), -1, -1, -1)
+ elif isinstance(prompt, str):
+ if batch_size > 1:
+ prompt = [prompt] * batch_size
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ prompt_2 = prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(DEVICE), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ prompt_embeds_list.append(prompt_embeds)
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=DEVICE)
+
+ # define initial latents
+ latents = self.image2latent(image, generator=None)
+
+ start_latents = latents
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = (height, width)
+ target_size = (height, width)
+ negative_original_size = original_size
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+
+ add_time_ids = add_time_ids.repeat(batch_size, 1).to(DEVICE)
+
+ # interative sampling
+ self.scheduler.set_timesteps(num_inference_steps)
+ latents_list = [latents]
+ pred_x0_list = []
+ # for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
+ for i, t in enumerate(reversed(self.scheduler.timesteps)):
+ model_inputs = latents
+
+ # predict the noise
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ model_inputs, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs
+ ).sample
+
+ # compute the previous noise sample x_t-1 -> x_t
+ latents, pred_x0 = self.next_step(noise_pred, t, latents)
+ """
+ if t >= 1 and t < 41:
+ latents, pred_x0 = self.next_step_degrade(noise_pred, t, latents, mask)
+ else:
+ latents, pred_x0 = self.next_step(noise_pred, t, latents) """
+
+ latents_list.append(latents)
+ pred_x0_list.append(pred_x0)
+
+ if return_intermediates:
+ # return the intermediate laters during inversion
+ # pred_x0_list = [self.latent2image(img, return_type="np") for img in pred_x0_list]
+ # latents_list = [self.latent2image(img, return_type="np") for img in latents_list]
+ return latents, latents_list, pred_x0_list
+ return latents, start_latents
+
+ def opt(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ x: torch.FloatTensor,
+ ):
+ """
+ predict the sampe the next step in the denoise process.
+ """
+ ref_noise = model_output[:1, :, :, :].expand(model_output.shape)
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+ x_opt = alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t) ** 0.5 * ref_noise
+ return x_opt, pred_x0
+
+ def regiter_attention_editor_diffusers(self, unet, editor: AttentionBase):
+ """
+ Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
+ """
+
+ def ca_forward(self, place_in_unet):
+ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
+ """
+ The attention is similar to the original implementation of LDM CrossAttention class
+ except adding some modifications on the attention
+ """
+ if encoder_hidden_states is not None:
+ context = encoder_hidden_states
+ if attention_mask is not None:
+ mask = attention_mask
+
+ to_out = self.to_out
+ if isinstance(to_out, nn.modules.container.ModuleList):
+ to_out = self.to_out[0]
+ else:
+ to_out = self.to_out
+
+ h = self.heads
+ q = self.to_q(x)
+ is_cross = context is not None
+ context = context if is_cross else x
+ k = self.to_k(context)
+ v = self.to_v(context)
+ # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))
+
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, "b ... -> b (...)")
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
+ mask = mask[:, None, :].repeat(h, 1, 1)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ attn = sim.softmax(dim=-1)
+ # the only difference
+ out = editor(q, k, v, sim, attn, is_cross, place_in_unet, self.heads, scale=self.scale)
+
+ return to_out(out)
+
+ return forward
+
+ def register_editor(net, count, place_in_unet):
+ for name, subnet in net.named_children():
+ if net.__class__.__name__ == "Attention": # spatial Transformer layer
+ net.forward = ca_forward(net, place_in_unet)
+ return count + 1
+ elif hasattr(net, "children"):
+ count = register_editor(subnet, count, place_in_unet)
+ return count
+
+ cross_att_count = 0
+ for net_name, net in unet.named_children():
+ if "down" in net_name:
+ cross_att_count += register_editor(net, 0, "down")
+ elif "mid" in net_name:
+ cross_att_count += register_editor(net, 0, "mid")
+ elif "up" in net_name:
+ cross_att_count += register_editor(net, 0, "up")
+ editor.num_att_layers = cross_att_count
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: torch.FloatTensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.9999,
+ AAS: bool = True, # AE parameter
+ rm_guidance_scale: float = 7.0, # AE parameter
+ ss_steps: int = 9, # AE parameter
+ ss_scale: float = 0.3, # AE parameter
+ AAS_start_step: int = 0, # AE parameter
+ AAS_start_layer: int = 34, # AE parameter
+ AAS_end_layer: int = 70, # AE parameter
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
+ `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
+ contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
+ the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+ and contain information inreleant for inpainging, such as background.
+ strength (`float`, *optional*, defaults to 0.9999):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
+ integer, the value of `strength` will be ignored.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
+ Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
+ if `do_classifier_free_guidance` is set to `True`.
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be as same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple. `tuple. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ height,
+ width,
+ strength,
+ callback_steps,
+ output_type,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ padding_mask_crop,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+ self._interrupt = False
+
+ ########### AE parameters
+ self._num_timesteps = num_inference_steps
+ self._rm_guidance_scale = rm_guidance_scale
+ self._AAS = AAS
+ self._ss_steps = ss_steps
+ self._ss_scale = ss_scale
+ self._AAS_start_step = AAS_start_step
+ self._AAS_start_layer = AAS_start_layer
+ self._AAS_end_layer = AAS_end_layer
+ ###########
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. set timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ mask = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is not None:
+ masked_image = masked_image_latents
+ elif init_image.shape[1] == 4:
+ # if images are in latent space, we can't mask it
+ masked_image = None
+ else:
+ masked_image = init_image * (mask < 0.5)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ add_noise = True if self.denoising_start is None else False
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ add_noise=add_noise,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+ # 8.1 Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 10. Prepare added time ids & embeddings
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ ###########
+ if self.do_self_attention_redirection_guidance:
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(2, 1)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+ ############
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # apply AAS to modify the attention module
+ if self.do_self_attention_redirection_guidance:
+ self._AAS_end_step = int(strength * self._num_timesteps)
+ layer_idx = list(range(self._AAS_start_layer, self._AAS_end_layer))
+ editor = AAS_XL(
+ self._AAS_start_step,
+ self._AAS_end_step,
+ self._AAS_start_layer,
+ self._AAS_end_layer,
+ layer_idx=layer_idx,
+ mask=mask_image,
+ model_type="SDXL",
+ ss_steps=self._ss_steps,
+ ss_scale=self._ss_scale,
+ )
+ self.regiter_attention_editor_diffusers(self.unet, editor)
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ if (
+ self.denoising_end is not None
+ and self.denoising_start is not None
+ and denoising_value_valid(self.denoising_end)
+ and denoising_value_valid(self.denoising_start)
+ and self.denoising_start >= self.denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {self.denoising_end} when using type float."
+ )
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 11.1 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ # removal guidance
+ latent_model_input = (
+ torch.cat([latents] * 2) if self.do_self_attention_redirection_guidance else latents
+ ) # CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
+ # latent_model_input_rm = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # latent_model_input = self.scheduler.scale_model_input(latent_model_input_rm, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform SARG
+ if self.do_self_attention_redirection_guidance:
+ noise_pred_wo, noise_pred_w = noise_pred.chunk(2)
+ delta = noise_pred_w - noise_pred_wo
+ noise_pred = noise_pred_wo + self._rm_guidance_scale * delta
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if self.do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+ mask = callback_outputs.pop("mask", mask)
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ latents = latents[-1:]
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ return StableDiffusionXLPipelineOutput(images=latents)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if padding_mask_crop is not None:
+ image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
index ae495979f366..e55be92962f2 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
@@ -226,12 +226,16 @@ def __init__(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
@@ -359,7 +363,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -419,7 +425,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
index 94ca71cf7b1b..8480117866cc 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
@@ -374,12 +374,16 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
@@ -507,7 +511,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -567,7 +573,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py b/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py
index 584820e86254..e74ea263017f 100644
--- a/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py
+++ b/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py
@@ -258,7 +258,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -394,7 +394,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -454,7 +456,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_stable_diffusion_xl_ipex.py b/examples/community/pipeline_stable_diffusion_xl_ipex.py
index 022dfb1abf82..f43726b1b5b8 100644
--- a/examples/community/pipeline_stable_diffusion_xl_ipex.py
+++ b/examples/community/pipeline_stable_diffusion_xl_ipex.py
@@ -253,10 +253,14 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -390,7 +394,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -450,7 +456,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_stg_cogvideox.py b/examples/community/pipeline_stg_cogvideox.py
new file mode 100644
index 000000000000..2e7f7906a36a
--- /dev/null
+++ b/examples/community/pipeline_stg_cogvideox.py
@@ -0,0 +1,876 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+import types
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.loaders import CogVideoXLoraLoaderMixin
+from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers.utils import export_to_video
+ >>> from examples.community.pipeline_stg_cogvideox import CogVideoXSTGPipeline
+
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
+ >>> pipe = CogVideoXSTGPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16).to("cuda")
+ >>> prompt = (
+ ... "A father and son building a treehouse together, their hands covered in sawdust and smiles on their faces, realistic style."
+ ... )
+ >>> pipe.transformer.to(memory_format=torch.channels_last)
+
+ >>> # Configure STG mode options
+ >>> stg_applied_layers_idx = [11] # Layer indices from 0 to 41
+ >>> stg_scale = 1.0 # Set to 0.0 for CFG
+ >>> do_rescaling = False
+
+ >>> # Generate video frames with STG parameters
+ >>> frames = pipe(
+ ... prompt=prompt,
+ ... stg_applied_layers_idx=stg_applied_layers_idx,
+ ... stg_scale=stg_scale,
+ ... do_rescaling=do_rescaling,
+ >>> ).frames[0]
+ >>> export_to_video(frames, "output.mp4", fps=8)
+ ```
+"""
+
+
+def forward_with_stg(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+) -> torch.Tensor:
+ hidden_states_ptb = hidden_states[2:]
+ encoder_hidden_states_ptb = encoder_hidden_states[2:]
+
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ hidden_states[2:] = hidden_states_ptb
+ encoder_hidden_states[2:] = encoder_hidden_states_ptb
+
+ return hidden_states, encoder_hidden_states
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class CogVideoXSTGPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using CogVideoX.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. CogVideoX uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CogVideoXTransformer3DModel`]):
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: CogVideoXTransformer3DModel,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+ self.vae_scale_factor_spatial = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ )
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ )
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ shape = (
+ batch_size,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
+ latents = 1 / self.vae_scaling_factor_image * latents
+
+ frames = self.vae.decode(latents).sample
+ return frames
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def fuse_qkv_projections(self) -> None:
+ r"""Enables fused QKV projections."""
+ self.fusing_transformer = True
+ self.transformer.fuse_qkv_projections()
+
+ def unfuse_qkv_projections(self) -> None:
+ r"""Disable QKV projection fusion if enabled."""
+ if not self.fusing_transformer:
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.transformer.unfuse_qkv_projections()
+ self.fusing_transformer = False
+
+ def _prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+ p = self.transformer.config.patch_size
+ p_t = self.transformer.config.patch_size_t
+
+ base_size_width = self.transformer.config.sample_width // p
+ base_size_height = self.transformer.config.sample_height // p
+
+ if p_t is None:
+ # CogVideoX 1.0
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ device=device,
+ )
+ else:
+ # CogVideoX 1.5
+ base_num_frames = (num_frames + p_t - 1) // p_t
+
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=None,
+ grid_size=(grid_height, grid_width),
+ temporal_size=base_num_frames,
+ grid_type="slice",
+ max_size=(base_size_height, base_size_width),
+ device=device,
+ )
+
+ return freqs_cos, freqs_sin
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_spatio_temporal_guidance(self):
+ return self._stg_scale > 0.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ stg_applied_layers_idx: Optional[List[int]] = [11],
+ stg_scale: Optional[float] = 0.0,
+ do_rescaling: Optional[bool] = False,
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
+ num_frames (`int`, defaults to `48`):
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
+ num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
+ needs to be satisfied is that of divisibility mentioned above.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `226`):
+ Maximum sequence length in encoded prompt. Must be consistent with
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+ num_frames = num_frames or self.transformer.config.sample_frames
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._stg_scale = stg_scale
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ if self.do_spatio_temporal_guidance:
+ for i in stg_applied_layers_idx:
+ self.transformer.transformer_blocks[i].forward = types.MethodType(
+ forward_with_stg, self.transformer.transformer_blocks[i]
+ )
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+ patch_size_t = self.transformer.config.patch_size_t
+ additional_frames = 0
+ if patch_size_t is not None and latent_frames % patch_size_t != 0:
+ additional_frames = patch_size_t - latent_frames % patch_size_t
+ num_frames += additional_frames * self.vae_scale_factor_temporal
+
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ latent_model_input = torch.cat([latents] * 2)
+ elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ latent_model_input = torch.cat([latents] * 3)
+ else:
+ latent_model_input = latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
+ noise_pred = (
+ noise_pred_uncond
+ + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ + self._stg_scale * (noise_pred_text - noise_pred_perturb)
+ )
+
+ if do_rescaling:
+ rescaling_scale = 0.7
+ factor = noise_pred_text.std() / noise_pred.std()
+ factor = rescaling_scale * factor + (1 - rescaling_scale)
+ noise_pred = noise_pred * factor
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ # Discard any padding frames that were added for CogVideoX 1.5
+ latents = latents[:, additional_frames:]
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CogVideoXPipelineOutput(frames=video)
diff --git a/examples/community/pipeline_stg_hunyuan_video.py b/examples/community/pipeline_stg_hunyuan_video.py
new file mode 100644
index 000000000000..e41f99e13a22
--- /dev/null
+++ b/examples/community/pipeline_stg_hunyuan_video.py
@@ -0,0 +1,794 @@
+# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import types
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.loaders import HunyuanVideoLoraLoaderMixin
+from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
+from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers.utils import export_to_video
+ >>> from diffusers import HunyuanVideoTransformer3DModel
+ >>> from examples.community.pipeline_stg_hunyuan_video import HunyuanVideoSTGPipeline
+
+ >>> model_id = "hunyuanvideo-community/HunyuanVideo"
+ >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe = HunyuanVideoSTGPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
+ >>> pipe.vae.enable_tiling()
+ >>> pipe.to("cuda")
+
+ >>> # Configure STG mode options
+ >>> stg_applied_layers_idx = [2] # Layer indices from 0 to 41
+ >>> stg_scale = 1.0 # Set 0.0 for CFG
+
+ >>> output = pipe(
+ ... prompt="A wolf howling at the moon, with the moon subtly resembling a giant clock face, realistic style.",
+ ... height=320,
+ ... width=512,
+ ... num_frames=61,
+ ... num_inference_steps=30,
+ ... stg_applied_layers_idx=stg_applied_layers_idx,
+ ... stg_scale=stg_scale,
+ >>> ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=15)
+ ```
+"""
+
+
+DEFAULT_PROMPT_TEMPLATE = {
+ "template": (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ ),
+ "crop_start": 95,
+}
+
+
+def forward_with_stg(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ return hidden_states, encoder_hidden_states
+
+
+def forward_without_stg(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Input normalization
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # 2. Joint attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=freqs_cis,
+ )
+
+ # 3. Modulation and residual connection
+ hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ # 4. Feed-forward
+ ff_output = self.ff(norm_hidden_states)
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return hidden_states, encoder_hidden_states
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using HunyuanVideo.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ text_encoder ([`LlamaModel`]):
+ [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
+ tokenizer (`LlamaTokenizer`):
+ Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
+ transformer ([`HunyuanVideoTransformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLHunyuanVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder_2 ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer_2 (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ text_encoder: LlamaModel,
+ tokenizer: LlamaTokenizerFast,
+ transformer: HunyuanVideoTransformer3DModel,
+ vae: AutoencoderKLHunyuanVideo,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ text_encoder_2: CLIPTextModel,
+ tokenizer_2: CLIPTokenizer,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_llama_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_template: Dict[str, Any],
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ num_hidden_layers_to_skip: int = 2,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ prompt = [prompt_template["template"].format(p) for p in prompt]
+
+ crop_start = prompt_template.get("crop_start", None)
+ if crop_start is None:
+ prompt_template_input = self.tokenizer(
+ prompt_template["template"],
+ padding="max_length",
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=False,
+ )
+ crop_start = prompt_template_input["input_ids"].shape[-1]
+ # Remove <|eot_id|> token and placeholder {}
+ crop_start -= 2
+
+ max_sequence_length += crop_start
+ text_inputs = self.tokenizer(
+ prompt,
+ max_length=max_sequence_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=True,
+ )
+ text_input_ids = text_inputs.input_ids.to(device=device)
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device)
+
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=prompt_attention_mask,
+ output_hidden_states=True,
+ ).hidden_states[-(num_hidden_layers_to_skip + 1)]
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
+
+ if crop_start is not None and crop_start > 0:
+ prompt_embeds = prompt_embeds[:, crop_start:]
+ prompt_attention_mask = prompt_attention_mask[:, crop_start:]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 77,
+ ) -> torch.Tensor:
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_2.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]] = None,
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ ):
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
+ prompt,
+ prompt_template,
+ num_videos_per_prompt,
+ device=device,
+ dtype=dtype,
+ max_sequence_length=max_sequence_length,
+ )
+
+ if pooled_prompt_embeds is None:
+ if prompt_2 is None and pooled_prompt_embeds is None:
+ prompt_2 = prompt
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt,
+ num_videos_per_prompt,
+ device=device,
+ dtype=dtype,
+ max_sequence_length=77,
+ )
+
+ return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_template=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_template is not None:
+ if not isinstance(prompt_template, dict):
+ raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
+ if "template" not in prompt_template:
+ raise ValueError(
+ f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
+ )
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: 32,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 129,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_spatio_temporal_guidance(self):
+ return self._stg_scale > 0.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Union[str, List[str]] = None,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 129,
+ num_inference_steps: int = 50,
+ sigmas: List[float] = None,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+ max_sequence_length: int = 256,
+ stg_applied_layers_idx: Optional[List[int]] = [2],
+ stg_scale: Optional[float] = 0.0,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ height (`int`, defaults to `720`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `129`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, defaults to `6.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
+ CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
+ not applied.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ prompt_template,
+ )
+
+ self._stg_scale = stg_scale
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_template=prompt_template,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
+ if pooled_prompt_embeds is not None:
+ pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ )
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_latent_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare guidance condition
+ guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ if self.do_spatio_temporal_guidance:
+ for i in stg_applied_layers_idx:
+ self.transformer.transformer_blocks[i].forward = types.MethodType(
+ forward_without_stg, self.transformer.transformer_blocks[i]
+ )
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ pooled_projections=pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_spatio_temporal_guidance:
+ for i in stg_applied_layers_idx:
+ self.transformer.transformer_blocks[i].forward = types.MethodType(
+ forward_with_stg, self.transformer.transformer_blocks[i]
+ )
+
+ noise_pred_perturb = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ pooled_projections=pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred + self._stg_scale * (noise_pred - noise_pred_perturb)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return HunyuanVideoPipelineOutput(frames=video)
diff --git a/examples/community/pipeline_stg_ltx.py b/examples/community/pipeline_stg_ltx.py
new file mode 100644
index 000000000000..4a257a0a9278
--- /dev/null
+++ b/examples/community/pipeline_stg_ltx.py
@@ -0,0 +1,886 @@
+# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import types
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKLLTXVideo
+from diffusers.models.transformers import LTXVideoTransformer3DModel
+from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers.utils import export_to_video
+ >>> from examples.community.pipeline_stg_ltx import LTXSTGPipeline
+
+ >>> pipe = LTXSTGPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage."
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> # Configure STG mode options
+ >>> stg_applied_layers_idx = [19] # Layer indices from 0 to 41
+ >>> stg_scale = 1.0 # Set 0.0 for CFG
+ >>> do_rescaling = False
+
+ >>> video = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=704,
+ ... height=480,
+ ... num_frames=161,
+ ... num_inference_steps=50,
+ ... stg_applied_layers_idx=stg_applied_layers_idx,
+ ... stg_scale=stg_scale,
+ ... do_rescaling=do_rescaling,
+ >>> ).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=24)
+ ```
+"""
+
+
+def forward_with_stg(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ hidden_states_ptb = hidden_states[2:]
+ encoder_hidden_states_ptb = encoder_hidden_states[2:]
+
+ batch_size = hidden_states.size(0)
+ norm_hidden_states = self.norm1(hidden_states)
+
+ num_ada_params = self.scale_shift_table.shape[0]
+ ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+
+ attn_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=None,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + attn_hidden_states * gate_msa
+
+ attn_hidden_states = self.attn2(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_rotary_emb=None,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = hidden_states + attn_hidden_states
+ norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
+
+ ff_output = self.ff(norm_hidden_states)
+ hidden_states = hidden_states + ff_output * gate_mlp
+
+ hidden_states[2:] = hidden_states_ptb
+ encoder_hidden_states[2:] = encoder_hidden_states_ptb
+
+ return hidden_states
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class LTXSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation.
+
+ Reference: https://github.com/Lightricks/LTX-Video
+
+ Args:
+ transformer ([`LTXVideoTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLLTXVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLLTXVideo,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: LTXVideoTransformer3DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
+ )
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ @staticmethod
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+ # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
+ batch_size, num_channels, num_frames, height, width = latents.shape
+ post_patch_num_frames = num_frames // patch_size_t
+ post_patch_height = height // patch_size
+ post_patch_width = width // patch_size
+ latents = latents.reshape(
+ batch_size,
+ -1,
+ post_patch_num_frames,
+ patch_size_t,
+ post_patch_height,
+ patch_size,
+ post_patch_width,
+ patch_size,
+ )
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+ return latents
+
+ @staticmethod
+ def _unpack_latents(
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+ ) -> torch.Tensor:
+ # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+ # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+ # what happens in the `_pack_latents` method.
+ batch_size = latents.size(0)
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ return latents
+
+ @staticmethod
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Normalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) * scaling_factor / latents_std
+ return latents
+
+ @staticmethod
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std / scaling_factor + latents_mean
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size: int = 1,
+ num_channels_latents: int = 128,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ height = height // self.vae_spatial_compression_ratio
+ width = width // self.vae_spatial_compression_ratio
+ num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def do_spatio_temporal_guidance(self):
+ return self._stg_scale > 0.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ frame_rate: int = 25,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ decode_timestep: Union[float, List[float]] = 0.0,
+ decode_noise_scale: Optional[Union[float, List[float]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 128,
+ stg_applied_layers_idx: Optional[List[int]] = [19],
+ stg_scale: Optional[float] = 1.0,
+ do_rescaling: Optional[bool] = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `512`):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+ width (`int`, defaults to `704`):
+ The width in pixels of the generated image. This is set to 848 by default for the best results.
+ num_frames (`int`, defaults to `161`):
+ The number of video frames to generate
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, defaults to `3 `):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ decode_timestep (`float`, defaults to `0.0`):
+ The timestep at which generated video is decoded.
+ decode_noise_scale (`float`, defaults to `None`):
+ The interpolation factor between random noise and denoised latents at the decode timestep.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `128 `):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._stg_scale = stg_scale
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ if self.do_spatio_temporal_guidance:
+ for i in stg_applied_layers_idx:
+ self.transformer.transformer_blocks[i].forward = types.MethodType(
+ forward_with_stg, self.transformer.transformer_blocks[i]
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+ elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat(
+ [negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+ video_sequence_length = latent_num_frames * latent_height * latent_width
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ mu = calculate_shift(
+ video_sequence_length,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.16),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare micro-conditions
+ latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
+ rope_interpolation_scale = (
+ 1 / latent_frame_rate,
+ self.vae_spatial_compression_ratio,
+ self.vae_spatial_compression_ratio,
+ )
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ latent_model_input = torch.cat([latents] * 2)
+ elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ latent_model_input = torch.cat([latents] * 3)
+ else:
+ latent_model_input = latents
+
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
+ noise_pred = (
+ noise_pred_uncond
+ + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ + self._stg_scale * (noise_pred_text - noise_pred_perturb)
+ )
+
+ if do_rescaling:
+ rescaling_scale = 0.7
+ factor = noise_pred_text.std() / noise_pred.std()
+ factor = rescaling_scale * factor + (1 - rescaling_scale)
+ noise_pred = noise_pred * factor
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = self._unpack_latents(
+ latents,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+ latents = self._denormalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ if not self.vae.config.timestep_conditioning:
+ timestep = None
+ else:
+ noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ if not isinstance(decode_timestep, list):
+ decode_timestep = [decode_timestep] * batch_size
+ if decode_noise_scale is None:
+ decode_noise_scale = decode_timestep
+ elif not isinstance(decode_noise_scale, list):
+ decode_noise_scale = [decode_noise_scale] * batch_size
+
+ timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
+ :, None, None, None, None
+ ]
+ latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
+
+ video = self.vae.decode(latents, timestep, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LTXPipelineOutput(frames=video)
diff --git a/examples/community/pipeline_stg_ltx_image2video.py b/examples/community/pipeline_stg_ltx_image2video.py
new file mode 100644
index 000000000000..5a3c3c5304e3
--- /dev/null
+++ b/examples/community/pipeline_stg_ltx_image2video.py
@@ -0,0 +1,985 @@
+# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import types
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.image_processor import PipelineImageInput
+from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from diffusers.models.autoencoders import AutoencoderKLLTXVideo
+from diffusers.models.transformers import LTXVideoTransformer3DModel
+from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers.utils import export_to_video, load_image
+ >>> from examples.community.pipeline_stg_ltx_image2video import LTXImageToVideoSTGPipeline
+
+ >>> pipe = LTXImageToVideoSTGPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/11.png"
+ >>> )
+ >>> prompt = "A medieval fantasy scene featuring a rugged man with shoulder-length brown hair and a beard. He wears a dark leather tunic over a maroon shirt with intricate metal details. His facial expression is serious and intense, and he is making a gesture with his right hand, forming a small circle with his thumb and index finger. The warm golden lighting casts dramatic shadows on his face. The background includes an ornate stone arch and blurred medieval-style decor, creating an epic atmosphere."
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> # Configure STG mode options
+ >>> stg_applied_layers_idx = [19] # Layer indices from 0 to 41
+ >>> stg_scale = 1.0 # Set 0.0 for CFG
+ >>> do_rescaling = False
+
+ >>> video = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=704,
+ ... height=480,
+ ... num_frames=161,
+ ... num_inference_steps=50,
+ ... stg_applied_layers_idx=stg_applied_layers_idx,
+ ... stg_scale=stg_scale,
+ ... do_rescaling=do_rescaling,
+ >>> ).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=24)
+ ```
+"""
+
+
+def forward_with_stg(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ hidden_states_ptb = hidden_states[2:]
+ encoder_hidden_states_ptb = encoder_hidden_states[2:]
+
+ batch_size = hidden_states.size(0)
+ norm_hidden_states = self.norm1(hidden_states)
+
+ num_ada_params = self.scale_shift_table.shape[0]
+ ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+
+ attn_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=None,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + attn_hidden_states * gate_msa
+
+ attn_hidden_states = self.attn2(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_rotary_emb=None,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = hidden_states + attn_hidden_states
+ norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
+
+ ff_output = self.ff(norm_hidden_states)
+ hidden_states = hidden_states + ff_output * gate_mlp
+
+ hidden_states[2:] = hidden_states_ptb
+ encoder_hidden_states[2:] = encoder_hidden_states_ptb
+
+ return hidden_states
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class LTXImageToVideoSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation.
+
+ Reference: https://github.com/Lightricks/LTX-Video
+
+ Args:
+ transformer ([`LTXVideoTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLLTXVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLLTXVideo,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: LTXVideoTransformer3DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
+ )
+
+ self.default_height = 512
+ self.default_width = 704
+ self.default_frames = 121
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+ # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
+ batch_size, num_channels, num_frames, height, width = latents.shape
+ post_patch_num_frames = num_frames // patch_size_t
+ post_patch_height = height // patch_size
+ post_patch_width = width // patch_size
+ latents = latents.reshape(
+ batch_size,
+ -1,
+ post_patch_num_frames,
+ patch_size_t,
+ post_patch_height,
+ patch_size,
+ post_patch_width,
+ patch_size,
+ )
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
+ def _unpack_latents(
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+ ) -> torch.Tensor:
+ # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+ # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+ # what happens in the `_pack_latents` method.
+ batch_size = latents.size(0)
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Normalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) * scaling_factor / latents_std
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std / scaling_factor + latents_mean
+ return latents
+
+ def prepare_latents(
+ self,
+ image: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ num_channels_latents: int = 128,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ height = height // self.vae_spatial_compression_ratio
+ width = width // self.vae_spatial_compression_ratio
+ num_frames = (
+ (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
+ )
+
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
+ mask_shape = (batch_size, 1, num_frames, height, width)
+
+ if latents is not None:
+ conditioning_mask = latents.new_zeros(shape)
+ conditioning_mask[:, :, 0] = 1.0
+ conditioning_mask = self._pack_latents(
+ conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+ return latents.to(device=device, dtype=dtype), conditioning_mask
+
+ if isinstance(generator, list):
+ if len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i])
+ for i in range(batch_size)
+ ]
+ else:
+ init_latents = [
+ retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator) for img in image
+ ]
+
+ init_latents = torch.cat(init_latents, dim=0).to(dtype)
+ init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
+ init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
+ conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
+ conditioning_mask[:, :, 0] = 1.0
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
+
+ conditioning_mask = self._pack_latents(
+ conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ ).squeeze(-1)
+ latents = self._pack_latents(
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+
+ return latents, conditioning_mask
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def do_spatio_temporal_guidance(self):
+ return self._stg_scale > 0.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ frame_rate: int = 25,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ decode_timestep: Union[float, List[float]] = 0.0,
+ decode_noise_scale: Optional[Union[float, List[float]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 128,
+ stg_applied_layers_idx: Optional[List[int]] = [19],
+ stg_scale: Optional[float] = 1.0,
+ do_rescaling: Optional[bool] = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `512`):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+ width (`int`, defaults to `704`):
+ The width in pixels of the generated image. This is set to 848 by default for the best results.
+ num_frames (`int`, defaults to `161`):
+ The number of video frames to generate
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, defaults to `3 `):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ decode_timestep (`float`, defaults to `0.0`):
+ The timestep at which generated video is decoded.
+ decode_noise_scale (`float`, defaults to `None`):
+ The interpolation factor between random noise and denoised latents at the decode timestep.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `128 `):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._stg_scale = stg_scale
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ if self.do_spatio_temporal_guidance:
+ for i in stg_applied_layers_idx:
+ self.transformer.transformer_blocks[i].forward = types.MethodType(
+ forward_with_stg, self.transformer.transformer_blocks[i]
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+ elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat(
+ [negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
+ )
+
+ # 4. Prepare latent variables
+ if latents is None:
+ image = self.video_processor.preprocess(image, height=height, width=width)
+ image = image.to(device=device, dtype=prompt_embeds.dtype)
+
+ num_channels_latents = self.transformer.config.in_channels
+ latents, conditioning_mask = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
+ elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ conditioning_mask = torch.cat([conditioning_mask, conditioning_mask, conditioning_mask])
+
+ # 5. Prepare timesteps
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+ video_sequence_length = latent_num_frames * latent_height * latent_width
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ mu = calculate_shift(
+ video_sequence_length,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.16),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare micro-conditions
+ latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
+ rope_interpolation_scale = (
+ 1 / latent_frame_rate,
+ self.vae_spatial_compression_ratio,
+ self.vae_spatial_compression_ratio,
+ )
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ latent_model_input = torch.cat([latents] * 2)
+ elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ latent_model_input = torch.cat([latents] * 3)
+ else:
+ latent_model_input = latents
+
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+ timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ timestep, _ = timestep.chunk(2)
+ elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
+ noise_pred = (
+ noise_pred_uncond
+ + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ + self._stg_scale * (noise_pred_text - noise_pred_perturb)
+ )
+ timestep, _, _ = timestep.chunk(3)
+
+ if do_rescaling:
+ rescaling_scale = 0.7
+ factor = noise_pred_text.std() / noise_pred.std()
+ factor = rescaling_scale * factor + (1 - rescaling_scale)
+ noise_pred = noise_pred * factor
+
+ # compute the previous noisy sample x_t -> x_t-1
+ noise_pred = self._unpack_latents(
+ noise_pred,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+ latents = self._unpack_latents(
+ latents,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+
+ noise_pred = noise_pred[:, :, 1:]
+ noise_latents = latents[:, :, 1:]
+ pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
+
+ latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
+ latents = self._pack_latents(
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = self._unpack_latents(
+ latents,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+ latents = self._denormalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ if not self.vae.config.timestep_conditioning:
+ timestep = None
+ else:
+ noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ if not isinstance(decode_timestep, list):
+ decode_timestep = [decode_timestep] * batch_size
+ if decode_noise_scale is None:
+ decode_noise_scale = decode_timestep
+ elif not isinstance(decode_noise_scale, list):
+ decode_noise_scale = [decode_noise_scale] * batch_size
+
+ timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
+ :, None, None, None, None
+ ]
+ latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
+
+ video = self.vae.decode(latents, timestep, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LTXPipelineOutput(frames=video)
diff --git a/examples/community/pipeline_stg_mochi.py b/examples/community/pipeline_stg_mochi.py
new file mode 100644
index 000000000000..97b7293d0ae3
--- /dev/null
+++ b/examples/community/pipeline_stg_mochi.py
@@ -0,0 +1,843 @@
+# Copyright 2024 Genmo and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import types
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.loaders import Mochi1LoraLoaderMixin
+from diffusers.models import AutoencoderKLMochi, MochiTransformer3DModel
+from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers.utils import export_to_video
+ >>> from examples.community.pipeline_stg_mochi import MochiSTGPipeline
+
+ >>> pipe = MochiSTGPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
+ >>> pipe.enable_model_cpu_offload()
+ >>> pipe.enable_vae_tiling()
+ >>> prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
+
+ >>> # Configure STG mode options
+ >>> stg_applied_layers_idx = [34] # Layer indices from 0 to 41
+ >>> stg_scale = 1.0 # Set 0.0 for CFG
+ >>> do_rescaling = False
+
+ >>> frames = pipe(
+ ... prompt=prompt,
+ ... num_inference_steps=28,
+ ... guidance_scale=3.5,
+ ... stg_applied_layers_idx=stg_applied_layers_idx,
+ ... stg_scale=stg_scale,
+ ... do_rescaling=do_rescaling).frames[0]
+ >>> export_to_video(frames, "mochi.mp4")
+ ```
+"""
+
+
+def forward_with_stg(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ encoder_attention_mask: torch.Tensor,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ hidden_states_ptb = hidden_states[2:]
+ encoder_hidden_states_ptb = encoder_hidden_states[2:]
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+
+ if not self.context_pre_only:
+ norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
+ encoder_hidden_states, temb
+ )
+ else:
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+
+ attn_hidden_states, context_attn_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=encoder_attention_mask,
+ )
+
+ hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
+ norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
+ ff_output = self.ff(norm_hidden_states)
+ hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
+
+ if not self.context_pre_only:
+ encoder_hidden_states = encoder_hidden_states + self.norm2_context(
+ context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
+ )
+ norm_encoder_hidden_states = self.norm3_context(
+ encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
+ )
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + self.norm4_context(
+ context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
+ )
+
+ hidden_states[2:] = hidden_states_ptb
+ encoder_hidden_states[2:] = encoder_hidden_states_ptb
+
+ return hidden_states, encoder_hidden_states
+
+
+# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
+def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
+ if linear_steps is None:
+ linear_steps = num_steps // 2
+ linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
+ threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
+ quadratic_steps = num_steps - linear_steps
+ quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
+ linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
+ const = quadratic_coef * (linear_steps**2)
+ quadratic_sigma_schedule = [
+ quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
+ ]
+ sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
+ sigma_schedule = [1.0 - x for x in sigma_schedule]
+ return sigma_schedule
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom value")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
+ r"""
+ The mochi pipeline for text-to-video generation.
+
+ Reference: https://github.com/genmoai/models
+
+ Args:
+ transformer ([`MochiTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLMochi`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLMochi,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: MochiTransformer3DModel,
+ force_zeros_for_empty_prompt: bool = False,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ # TODO: determine these scaling factors from model parameters
+ self.vae_spatial_scale_factor = 8
+ self.vae_temporal_scale_factor = 6
+ self.patch_size = 2
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
+ )
+ self.default_height = 480
+ self.default_width = 848
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 256,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
+
+ # The original Mochi implementation zeros out empty negative prompts
+ # but this can lead to overflow when placing the entire pipeline under the autocast context
+ # adding this here so that we can enable zeroing prompts if necessary
+ if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
+ text_input_ids = torch.zeros_like(text_input_ids, device=device)
+ prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 256,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = height // self.vae_spatial_scale_factor
+ width = width // self.vae_spatial_scale_factor
+ num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
+
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
+ latents = latents.to(dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def do_spatio_temporal_guidance(self):
+ return self._stg_scale > 0.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: int = 19,
+ num_inference_steps: int = 64,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.5,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ stg_applied_layers_idx: Optional[List[int]] = [34],
+ stg_scale: Optional[float] = 0.0,
+ do_rescaling: Optional[bool] = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, *optional*, defaults to `self.default_height`):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+ width (`int`, *optional*, defaults to `self.default_width`):
+ The width in pixels of the generated image. This is set to 848 by default for the best results.
+ num_frames (`int`, defaults to `19`):
+ The number of video frames to generate
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, defaults to `4.5`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `256`):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
+ is returned where the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.default_height
+ width = width or self.default_width
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._stg_scale = stg_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ if self.do_spatio_temporal_guidance:
+ for i in stg_applied_layers_idx:
+ self.transformer.transformer_blocks[i].forward = types.MethodType(
+ forward_with_stg, self.transformer.transformer_blocks[i]
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # 3. Prepare text embeddings
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+ elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat(
+ [negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
+ )
+
+ # 5. Prepare timestep
+ # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
+ threshold_noise = 0.025
+ sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
+ sigmas = np.array(sigmas)
+
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
+ # to make sure we're using the correct non-reversed timestep value.
+ self._current_timestep = 1000 - t
+ if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ latent_model_input = torch.cat([latents] * 2)
+ elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ latent_model_input = torch.cat([latents] * 3)
+ else:
+ latent_model_input = latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ # Mochi CFG + Sampling runs in FP32
+ noise_pred = noise_pred.to(torch.float32)
+
+ if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
+ noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
+ noise_pred = (
+ noise_pred_uncond
+ + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ + self._stg_scale * (noise_pred_text - noise_pred_perturb)
+ )
+
+ if do_rescaling:
+ rescaling_scale = 0.7
+ factor = noise_pred_text.std() / noise_pred.std()
+ factor = rescaling_scale * factor + (1 - rescaling_scale)
+ noise_pred = noise_pred * factor
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
+ latents = latents.to(latents_dtype)
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ video = latents
+ else:
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return MochiPipelineOutput(frames=video)
diff --git a/examples/community/pipeline_zero1to3.py b/examples/community/pipeline_zero1to3.py
index 95bb37ce02b7..9a34f91bf841 100644
--- a/examples/community/pipeline_zero1to3.py
+++ b/examples/community/pipeline_zero1to3.py
@@ -108,7 +108,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -122,7 +122,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -151,10 +151,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -181,7 +185,7 @@ def __init__(
feature_extractor=feature_extractor,
cc_projection=cc_projection,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
# self.model_mode = None
diff --git a/examples/community/regional_prompting_stable_diffusion.py b/examples/community/regional_prompting_stable_diffusion.py
index 8a022987ba9d..9f09b4bd2bba 100644
--- a/examples/community/regional_prompting_stable_diffusion.py
+++ b/examples/community/regional_prompting_stable_diffusion.py
@@ -3,13 +3,12 @@
import torch
import torchvision.transforms.functional as FF
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import USE_PEFT_BACKEND
try:
@@ -17,6 +16,7 @@
except ImportError:
Compel = None
+KBASE = "ADDBASE"
KCOMM = "ADDCOMM"
KBRK = "BREAK"
@@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
Optional
rp_args["save_mask"]: True/False (save masks in prompt mode)
+ rp_args["power"]: int (power for attention maps in prompt mode)
+ rp_args["base_ratio"]:
+ float (Sets the ratio of the base prompt)
+ ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)
+ [Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)
Pipeline for text-to-image generation using Stable Diffusion.
@@ -70,6 +75,7 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__(
@@ -80,6 +86,7 @@ def __init__(
scheduler,
safety_checker,
feature_extractor,
+ image_encoder,
requires_safety_checker,
)
self.register_modules(
@@ -90,6 +97,7 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
)
@torch.no_grad()
@@ -110,17 +118,40 @@ def __call__(
rp_args: Dict[str, str] = None,
):
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
+ use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt
if negative_prompt is None:
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)
device = self._execution_device
regions = 0
+ self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0
self.power = int(rp_args["power"]) if "power" in rp_args else 1
prompts = prompt if isinstance(prompt, list) else [prompt]
- n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
+ n_prompts = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
self.batch = batch = num_images_per_prompt * len(prompts)
+
+ if use_base:
+ bases = prompts.copy()
+ n_bases = n_prompts.copy()
+
+ for i, prompt in enumerate(prompts):
+ parts = prompt.split(KBASE)
+ if len(parts) == 2:
+ bases[i], prompts[i] = parts
+ elif len(parts) > 2:
+ raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")
+ for i, prompt in enumerate(n_prompts):
+ n_parts = prompt.split(KBASE)
+ if len(n_parts) == 2:
+ n_bases[i], n_prompts[i] = n_parts
+ elif len(n_parts) > 2:
+ raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}")
+
+ all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
+ all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)
+
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
@@ -137,8 +168,16 @@ def getcompelembs(prps):
conds = getcompelembs(all_prompts_cn)
unconds = getcompelembs(all_n_prompts_cn)
- embs = getcompelembs(prompts)
- n_embs = getcompelembs(n_prompts)
+ base_embs = getcompelembs(all_bases_cn) if use_base else None
+ base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None
+ # When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
+ embs = getcompelembs(prompts) if not use_base else base_embs
+ n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs
+
+ if use_base and self.base_ratio > 0:
+ conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
+ unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
+
prompt = negative_prompt = None
else:
conds = self.encode_prompt(prompts, device, 1, True)[0]
@@ -147,6 +186,18 @@ def getcompelembs(prps):
if equal
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
)
+
+ if use_base and self.base_ratio > 0:
+ base_embs = self.encode_prompt(bases, device, 1, True)[0]
+ base_n_embs = (
+ self.encode_prompt(n_bases, device, 1, True)[0]
+ if equal
+ else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]
+ )
+
+ conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
+ unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
+
embs = n_embs = None
if not active:
@@ -225,8 +276,6 @@ def forward(
residual = hidden_states
- args = () if USE_PEFT_BACKEND else (scale,)
-
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -247,16 +296,15 @@ def forward(
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- args = () if USE_PEFT_BACKEND else (scale,)
- query = attn.to_q(hidden_states, *args)
+ query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states, *args)
- value = attn.to_v(encoder_hidden_states, *args)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -283,7 +331,7 @@ def forward(
hidden_states = hidden_states.to(query.dtype)
# linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
+ hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
add = ""
if KCOMM in prompt:
add, prompt = prompt.split(KCOMM)
- add = add + " "
- prompts = prompt.split(KBRK)
- out_p.append([add + p for p in prompts])
+ add = add.strip() + " "
+ prompts = [p.strip() for p in prompt.split(KBRK)]
+ out_p.append([add + p for i, p in enumerate(prompts)])
out = [None] * batch * len(out_p[0]) * len(out_p)
for p, prs in enumerate(out_p): # inputs prompts
for r, pr in enumerate(prs): # prompts for regions
@@ -449,7 +497,6 @@ def startend(cells, array):
add = []
startend(add, inratios[1:])
icells.append(add)
-
return ocells, icells, sum(len(cell) for cell in icells)
diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py
index d9c616ab5ebc..7e66bff51d3b 100644
--- a/examples/community/rerender_a_video.py
+++ b/examples/community/rerender_a_video.py
@@ -30,10 +30,17 @@
from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import BaseOutput, deprecate, logging
+from diffusers.utils import BaseOutput, deprecate, is_torch_xla_available, logging
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -345,7 +352,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -625,7 +632,7 @@ def __call__(
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process.
- control_frames (`List[np.ndarray]` or `torch.Tensor`): The ControlNet input images condition to provide guidance to the `unet` for generation.
+ control_frames (`List[np.ndarray]` or `torch.Tensor` or `Callable`): The ControlNet input images condition to provide guidance to the `unet` for generation or any callable object to convert frame to control_frame.
strength ('float'): SDEdit strength.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -775,14 +782,14 @@ def __call__(
self.attn_state.reset()
# 4.1 prepare frames
- image = self.image_processor.preprocess(frames[0]).to(dtype=torch.float32)
+ image = self.image_processor.preprocess(frames[0]).to(dtype=self.dtype)
first_image = image[0] # C, H, W
# 4.2 Prepare controlnet_conditioning_image
# Currently we only support single control
if isinstance(controlnet, ControlNetModel):
control_image = self.prepare_control_image(
- image=control_frames[0],
+ image=control_frames(frames[0]) if callable(control_frames) else control_frames[0],
width=width,
height=height,
batch_size=batch_size,
@@ -901,6 +908,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
@@ -917,10 +927,10 @@ def __call__(
for idx in range(1, len(frames)):
image = frames[idx]
prev_image = frames[idx - 1]
- control_image = control_frames[idx]
+ control_image = control_frames(image) if callable(control_frames) else control_frames[idx]
# 5.1 prepare frames
- image = self.image_processor.preprocess(image).to(dtype=torch.float32)
- prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
+ image = self.image_processor.preprocess(image).to(dtype=self.dtype)
+ prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)
warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
self.flow_model, first_image, image[0], first_result, False, self.device
@@ -1100,6 +1110,9 @@ def denoising_loop(latents, mask=None, xtrg=None, noise_rescale=None):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
return latents
if mask_start_t <= mask_end_t:
diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py
index c7c88d6fdcc7..6aa4067d695d 100644
--- a/examples/community/stable_diffusion_controlnet_img2img.py
+++ b/examples/community/stable_diffusion_controlnet_img2img.py
@@ -179,7 +179,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py
index b473ffe79933..2d19e26b4220 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint.py
@@ -278,7 +278,7 @@ def __init__(
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
index 8928f34239e3..4363a2294b63 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
@@ -263,7 +263,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py
index 123892f6229a..b2d4541797f5 100644
--- a/examples/community/stable_diffusion_ipex.py
+++ b/examples/community/stable_diffusion_ipex.py
@@ -105,7 +105,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -119,7 +119,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -148,10 +148,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -178,7 +182,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1):
diff --git a/examples/community/stable_diffusion_mega.py b/examples/community/stable_diffusion_mega.py
index 95b4b03e4de1..77e5011d2a70 100644
--- a/examples/community/stable_diffusion_mega.py
+++ b/examples/community/stable_diffusion_mega.py
@@ -66,7 +66,7 @@ def __init__(
requires_safety_checker: bool = True,
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py
index efb0fa89dbfc..9ef95a52051d 100644
--- a/examples/community/stable_diffusion_reference.py
+++ b/examples/community/stable_diffusion_reference.py
@@ -132,7 +132,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -146,7 +146,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -181,10 +181,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -202,7 +206,7 @@ def __init__(
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
- if unet.config.in_channels != 4:
+ if unet is not None and unet.config.in_channels != 4:
logger.warning(
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
@@ -219,7 +223,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/stable_diffusion_repaint.py b/examples/community/stable_diffusion_repaint.py
index 980e9a155997..0bc28eca15cc 100644
--- a/examples/community/stable_diffusion_repaint.py
+++ b/examples/community/stable_diffusion_repaint.py
@@ -187,7 +187,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -201,7 +201,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -236,10 +236,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -257,7 +261,7 @@ def __init__(
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
- if unet.config.in_channels != 4:
+ if unet is not None and unet.config.in_channels != 4:
logger.warning(
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
@@ -274,7 +278,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py
index 91540d1f4159..f2d184bb73e0 100755
--- a/examples/community/stable_diffusion_tensorrt_img2img.py
+++ b/examples/community/stable_diffusion_tensorrt_img2img.py
@@ -1,5 +1,5 @@
#
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -710,7 +710,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -724,7 +724,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -753,10 +753,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -806,7 +810,7 @@ def __init__(
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/stable_diffusion_tensorrt_inpaint.py b/examples/community/stable_diffusion_tensorrt_inpaint.py
index b6f6711a53e7..8da37d37acbb 100755
--- a/examples/community/stable_diffusion_tensorrt_inpaint.py
+++ b/examples/community/stable_diffusion_tensorrt_inpaint.py
@@ -1,5 +1,5 @@
#
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -714,7 +714,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -728,7 +728,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -757,10 +757,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -810,7 +814,7 @@ def __init__(
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py
index f8761053ed1a..a3f9aae371b0 100755
--- a/examples/community/stable_diffusion_tensorrt_txt2img.py
+++ b/examples/community/stable_diffusion_tensorrt_txt2img.py
@@ -1,5 +1,5 @@
#
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -626,7 +626,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -640,7 +640,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -669,10 +669,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -722,7 +726,7 @@ def __init__(
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/stable_diffusion_xl_controlnet_reference.py b/examples/community/stable_diffusion_xl_controlnet_reference.py
new file mode 100644
index 000000000000..2c9bef311b0e
--- /dev/null
+++ b/examples/community/stable_diffusion_xl_controlnet_reference.py
@@ -0,0 +1,1368 @@
+# Based on stable_diffusion_xl_reference.py and stable_diffusion_controlnet_reference.py
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers import StableDiffusionXLControlNetPipeline
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.image_processor import PipelineImageInput
+from diffusers.models import ControlNetModel
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
+from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from diffusers import ControlNetModel, AutoencoderKL
+ >>> from diffusers.schedulers import UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> import cv2
+ >>> from PIL import Image
+
+ >>> # download an image for the Canny controlnet
+ >>> canny_image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg"
+ ... )
+
+ >>> # download an image for the Reference controlnet
+ >>> ref_image = load_image(
+ ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+ ... )
+
+ >>> # initialize the models and pipeline
+ >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization
+ >>> controlnet = ControlNetModel.from_pretrained(
+ ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
+ ... )
+ >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionXLControlNetReferencePipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
+ ... ).to("cuda:0")
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+ >>> # get canny image
+ >>> image = np.array(canny_image)
+ >>> image = cv2.Canny(image, 100, 200)
+ >>> image = image[:, :, None]
+ >>> image = np.concatenate([image, image, image], axis=2)
+ >>> canny_image = Image.fromarray(image)
+
+ >>> # generate image
+ >>> image = pipe(
+ ... prompt="a cat",
+ ... num_inference_steps=20,
+ ... controlnet_conditioning_scale=controlnet_conditioning_scale,
+ ... image=canny_image,
+ ... ref_image=ref_image,
+ ... reference_attn=True,
+ ... reference_adain=True
+ ... style_fidelity=1.0,
+ ... generator=torch.Generator("cuda").manual_seed(42)
+ ... ).images[0]
+ ```
+"""
+
+
+def torch_dfs(model: torch.nn.Module):
+ result = [model]
+ for child in model.children():
+ result += torch_dfs(child)
+ return result
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
+ Second frozen text-encoder
+ ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ tokenizer_2 ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
+ additional conditioning.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings should always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
+ watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
+ watermarker is used.
+ """
+
+ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
+ refimage = refimage.to(device=device)
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.upcast_vae()
+ refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ if refimage.dtype != self.vae.dtype:
+ refimage = refimage.to(dtype=self.vae.dtype)
+ # encode the mask image into latents space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ ref_image_latents = [
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
+ else:
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
+
+ # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
+ if ref_image_latents.shape[0] < batch_size:
+ if not batch_size % ref_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
+
+ ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ return ref_image_latents
+
+ def prepare_ref_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if not isinstance(image, torch.Tensor):
+ if isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ images = []
+
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+ image_ = np.array(image_)
+ image_ = image_[None, :]
+ images.append(image_)
+
+ image = images
+
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = (image - 0.5) / 0.5
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.stack(image, dim=0)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def check_ref_inputs(
+ self,
+ ref_image,
+ reference_guidance_start,
+ reference_guidance_end,
+ style_fidelity,
+ reference_attn,
+ reference_adain,
+ ):
+ ref_image_is_pil = isinstance(ref_image, PIL.Image.Image)
+ ref_image_is_tensor = isinstance(ref_image, torch.Tensor)
+
+ if not ref_image_is_pil and not ref_image_is_tensor:
+ raise TypeError(
+ f"ref image must be passed and be one of PIL image or torch tensor, but is {type(ref_image)}"
+ )
+
+ if not reference_attn and not reference_adain:
+ raise ValueError("`reference_attn` or `reference_adain` must be True.")
+
+ if style_fidelity < 0.0:
+ raise ValueError(f"style fidelity: {style_fidelity} can't be smaller than 0.")
+ if style_fidelity > 1.0:
+ raise ValueError(f"style fidelity: {style_fidelity} can't be larger than 1.0.")
+
+ if reference_guidance_start >= reference_guidance_end:
+ raise ValueError(
+ f"reference guidance start: {reference_guidance_start} cannot be larger or equal to reference guidance end: {reference_guidance_end}."
+ )
+ if reference_guidance_start < 0.0:
+ raise ValueError(f"reference guidance start: {reference_guidance_start} can't be smaller than 0.")
+ if reference_guidance_end > 1.0:
+ raise ValueError(f"reference guidance end: {reference_guidance_end} can't be larger than 1.0.")
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ ref_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ attention_auto_machine_weight: float = 1.0,
+ gn_auto_machine_weight: float = 1.0,
+ reference_guidance_start: float = 0.0,
+ reference_guidance_end: float = 1.0,
+ style_fidelity: float = 0.5,
+ reference_attn: bool = True,
+ reference_adain: bool = True,
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
+ ref_image (`torch.Tensor`, `PIL.Image.Image`):
+ The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
+ the type is specified as `Torch.Tensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
+ also be accepted as an image.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, pooled text embeddings are generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
+ argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be as same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ attention_auto_machine_weight (`float`):
+ Weight of using reference query for self attention's context.
+ If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
+ gn_auto_machine_weight (`float`):
+ Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.
+ reference_guidance_start (`float`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the reference ControlNet starts applying.
+ reference_guidance_end (`float`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the reference ControlNet stops applying.
+ style_fidelity (`float`):
+ style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important,
+ elif style_fidelity=0.0, prompt more important, else balanced.
+ reference_attn (`bool`):
+ Whether to use reference query for self attention's context.
+ reference_adain (`bool`):
+ Whether to use reference adain.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned containing the output images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ negative_pooled_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self.check_ref_inputs(
+ ref_image,
+ reference_guidance_start,
+ reference_guidance_end,
+ style_fidelity,
+ reference_attn,
+ reference_adain,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3.1 Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt,
+ prompt_2,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 3.2 Encode ip_adapter_image
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 4. Prepare image
+ if isinstance(controlnet, ControlNetModel):
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ height, width = image.shape[-2:]
+ elif isinstance(controlnet, MultiControlNetModel):
+ images = []
+
+ for image_ in image:
+ image_ = self.prepare_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+
+ image = images
+ height, width = image[0].shape[-2:]
+ else:
+ assert False
+
+ # 5. Preprocess reference image
+ ref_image = self.prepare_ref_image(
+ image=ref_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=prompt_embeds.dtype,
+ )
+
+ # 6. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+ self._num_timesteps = len(timesteps)
+
+ # 7. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 8. Prepare reference latent variables
+ ref_image_latents = self.prepare_ref_latents(
+ ref_image,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9.1 Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ reference_keeps = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+ reference_keep = 1.0 - float(
+ i / len(timesteps) < reference_guidance_start or (i + 1) / len(timesteps) > reference_guidance_end
+ )
+ reference_keeps.append(reference_keep)
+
+ # 9.2 Modify self attention and group norm
+ MODE = "write"
+ uc_mask = (
+ torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
+ .type_as(ref_image_latents)
+ .bool()
+ )
+
+ do_classifier_free_guidance = self.do_classifier_free_guidance
+
+ def hacked_basic_transformer_inner_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if self.only_cross_attention:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ if MODE == "write":
+ self.bank.append(norm_hidden_states.detach().clone())
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if MODE == "read":
+ if attention_auto_machine_weight > self.attn_weight:
+ attn_output_uc = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
+ # attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output_c = attn_output_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ attn_output_c[uc_mask] = self.attn1(
+ norm_hidden_states[uc_mask],
+ encoder_hidden_states=norm_hidden_states[uc_mask],
+ **cross_attention_kwargs,
+ )
+ attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
+ self.bank.clear()
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+ def hacked_mid_forward(self, *args, **kwargs):
+ eps = 1e-6
+ x = self.original_forward(*args, **kwargs)
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append(mean)
+ self.var_bank.append(var)
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
+ x_c = x_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ x_c[uc_mask] = x[uc_mask]
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
+ self.mean_bank = []
+ self.var_bank = []
+ return x
+
+ def hack_CrossAttnDownBlock2D_forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ eps = 1e-6
+
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs):
+ eps = 1e-6
+
+ output_states = ()
+
+ for i, resnet in enumerate(self.resnets):
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_CrossAttnUpBlock2D_forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+ temb: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ eps = 1e-6
+ # TODO(Patrick, William) - attention mask is not used
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def hacked_UpBlock2D_forward(
+ self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, *args, **kwargs
+ ):
+ eps = 1e-6
+ for i, resnet in enumerate(self.resnets):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ if reference_attn:
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ module._original_inner_forward = module.forward
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+ module.bank = []
+ module.attn_weight = float(i) / float(len(attn_modules))
+
+ if reference_adain:
+ gn_modules = [self.unet.mid_block]
+ self.unet.mid_block.gn_weight = 0
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
+ gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ module.gn_weight = float(w) / float(len(up_blocks))
+ gn_modules.append(module)
+
+ for i, module in enumerate(gn_modules):
+ if getattr(module, "original_forward", None) is None:
+ module.original_forward = module.forward
+ if i == 0:
+ # mid_block
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
+ elif isinstance(module, CrossAttnDownBlock2D):
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
+ elif isinstance(module, DownBlock2D):
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
+ elif isinstance(module, CrossAttnUpBlock2D):
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
+ elif isinstance(module, UpBlock2D):
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
+ module.mean_bank = []
+ module.var_bank = []
+ module.gn_weight *= 2
+
+ # 9.2 Prepare added time ids & embeddings
+ if isinstance(image, list):
+ original_size = original_size or image[0].shape[-2:]
+ else:
+ original_size = original_size or image.shape[-2:]
+ target_size = target_size or (height, width)
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # 10.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ is_unet_compiled = is_compiled_module(self.unet)
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # Relevant thread:
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ torch._inductor.cudagraph_mark_step_begin()
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ # controlnet(s) inference
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ controlnet_added_cond_kwargs = {
+ "text_embeds": add_text_embeds.chunk(2)[1],
+ "time_ids": add_time_ids.chunk(2)[1],
+ }
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_added_cond_kwargs = added_cond_kwargs
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ # ref only part
+ if reference_keeps[i] > 0:
+ noise = randn_tensor(
+ ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
+ )
+ ref_xt = self.scheduler.add_noise(
+ ref_image_latents,
+ noise,
+ t.reshape(
+ 1,
+ ),
+ )
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
+
+ MODE = "write"
+ self.unet(
+ ref_xt,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ MODE = "read"
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/examples/community/stable_diffusion_xl_reference.py b/examples/community/stable_diffusion_xl_reference.py
index 107afc1f8b7a..e01eac970b58 100644
--- a/examples/community/stable_diffusion_xl_reference.py
+++ b/examples/community/stable_diffusion_xl_reference.py
@@ -1,5 +1,6 @@
# Based on stable_diffusion_reference.py
+import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -7,28 +8,33 @@
import torch
from diffusers import StableDiffusionXLPipeline
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.image_processor import PipelineImageInput
from diffusers.models.attention import BasicTransformerBlock
-from diffusers.models.unets.unet_2d_blocks import (
- CrossAttnDownBlock2D,
- CrossAttnUpBlock2D,
- DownBlock2D,
- UpBlock2D,
-)
-from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
-from diffusers.utils import PIL_INTERPOLATION, logging
+from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm # type: ignore
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
- >>> from diffusers import UniPCMultistepScheduler
+ >>> from diffusers.schedulers import UniPCMultistepScheduler
>>> from diffusers.utils import load_image
- >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+ >>> input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg")
>>> pipe = StableDiffusionXLReferencePipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
@@ -38,7 +44,7 @@
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
>>> result_img = pipe(ref_image=input_image,
- prompt="1girl",
+ prompt="a dog",
num_inference_steps=20,
reference_attn=True,
reference_adain=True).images[0]
@@ -56,8 +62,6 @@ def torch_dfs(model: torch.nn.Module):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
-
-
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -72,33 +76,108 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
- def _default_height_width(self, height, width, image):
- # NOTE: It is possible that a list of images have different
- # dimensions for each image, so just checking the first image
- # is not _exactly_ correct, but it is simple.
- while isinstance(image, list):
- image = image[0]
-
- if height is None:
- if isinstance(image, PIL.Image.Image):
- height = image.height
- elif isinstance(image, torch.Tensor):
- height = image.shape[2]
+ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
+ refimage = refimage.to(device=device)
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.upcast_vae()
+ refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ if refimage.dtype != self.vae.dtype:
+ refimage = refimage.to(dtype=self.vae.dtype)
+ # encode the mask image into latents space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ ref_image_latents = [
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
+ else:
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
- height = (height // 8) * 8 # round down to nearest multiple of 8
+ # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
+ if ref_image_latents.shape[0] < batch_size:
+ if not batch_size % ref_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
- if width is None:
- if isinstance(image, PIL.Image.Image):
- width = image.width
- elif isinstance(image, torch.Tensor):
- width = image.shape[3]
+ ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
- width = (width // 8) * 8
+ # aligning device to prevent device errors when concating it with the latent model input
+ ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
- return height, width
+ return ref_image_latents
- def prepare_image(
+ def prepare_ref_image(
self,
image,
width,
@@ -151,41 +230,42 @@ def prepare_image(
return image
- def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
- refimage = refimage.to(device=device)
- if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
- self.upcast_vae()
- refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
- if refimage.dtype != self.vae.dtype:
- refimage = refimage.to(dtype=self.vae.dtype)
- # encode the mask image into latents space so we can concatenate it to the latents
- if isinstance(generator, list):
- ref_image_latents = [
- self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
- for i in range(batch_size)
- ]
- ref_image_latents = torch.cat(ref_image_latents, dim=0)
- else:
- ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
- ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
+ def check_ref_inputs(
+ self,
+ ref_image,
+ reference_guidance_start,
+ reference_guidance_end,
+ style_fidelity,
+ reference_attn,
+ reference_adain,
+ ):
+ ref_image_is_pil = isinstance(ref_image, PIL.Image.Image)
+ ref_image_is_tensor = isinstance(ref_image, torch.Tensor)
- # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
- if ref_image_latents.shape[0] < batch_size:
- if not batch_size % ref_image_latents.shape[0] == 0:
- raise ValueError(
- "The passed images and the required batch size don't match. Images are supposed to be duplicated"
- f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
- " Make sure the number of images that you pass is divisible by the total requested batch size."
- )
- ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
+ if not ref_image_is_pil and not ref_image_is_tensor:
+ raise TypeError(
+ f"ref image must be passed and be one of PIL image or torch tensor, but is {type(ref_image)}"
+ )
- ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
+ if not reference_attn and not reference_adain:
+ raise ValueError("`reference_attn` or `reference_adain` must be True.")
- # aligning device to prevent device errors when concating it with the latent model input
- ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
- return ref_image_latents
+ if style_fidelity < 0.0:
+ raise ValueError(f"style fidelity: {style_fidelity} can't be smaller than 0.")
+ if style_fidelity > 1.0:
+ raise ValueError(f"style fidelity: {style_fidelity} can't be larger than 1.0.")
+
+ if reference_guidance_start >= reference_guidance_end:
+ raise ValueError(
+ f"reference guidance start: {reference_guidance_start} cannot be larger or equal to reference guidance end: {reference_guidance_end}."
+ )
+ if reference_guidance_start < 0.0:
+ raise ValueError(f"reference guidance start: {reference_guidance_start} can't be smaller than 0.")
+ if reference_guidance_end > 1.0:
+ raise ValueError(f"reference guidance end: {reference_guidance_end} can't be larger than 1.0.")
@torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
@@ -194,6 +274,8 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -206,28 +288,220 @@ def __call__(
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
- callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
attention_auto_machine_weight: float = 1.0,
gn_auto_machine_weight: float = 1.0,
+ reference_guidance_start: float = 0.0,
+ reference_guidance_end: float = 1.0,
style_fidelity: float = 0.5,
reference_attn: bool = True,
reference_adain: bool = True,
+ **kwargs,
):
- assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ ref_image (`torch.Tensor`, `PIL.Image.Image`):
+ The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
+ the type is specified as `Torch.Tensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
+ also be accepted as an image.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be as same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ attention_auto_machine_weight (`float`):
+ Weight of using reference query for self attention's context.
+ If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
+ gn_auto_machine_weight (`float`):
+ Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.
+ reference_guidance_start (`float`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the reference ControlNet starts applying.
+ reference_guidance_end (`float`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the reference ControlNet stops applying.
+ style_fidelity (`float`):
+ style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important,
+ elif style_fidelity=0.0, prompt more important, else balanced.
+ reference_attn (`bool`):
+ Whether to use reference query for self attention's context.
+ reference_adain (`bool`):
+ Whether to use reference adain.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
- # 0. Default height and width to unet
- # height, width = self._default_height_width(height, width, ref_image)
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ # 0. Default height and width to unet
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
+
original_size = original_size or (height, width)
target_size = target_size or (height, width)
@@ -244,8 +518,27 @@ def __call__(
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
)
+ self.check_ref_inputs(
+ ref_image,
+ reference_guidance_start,
+ reference_guidance_end,
+ style_fidelity,
+ reference_attn,
+ reference_adain,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._interrupt = False
+
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -256,15 +549,11 @@ def __call__(
device = self._execution_device
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = guidance_scale > 1.0
-
# 3. Encode input prompt
- text_encoder_lora_scale = (
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
+
(
prompt_embeds,
negative_prompt_embeds,
@@ -275,17 +564,19 @@ def __call__(
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
- do_classifier_free_guidance=do_classifier_free_guidance,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- lora_scale=text_encoder_lora_scale,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
)
+
# 4. Preprocess reference image
- ref_image = self.prepare_image(
+ ref_image = self.prepare_ref_image(
image=ref_image,
width=width,
height=height,
@@ -296,9 +587,9 @@ def __call__(
)
# 5. Prepare timesteps
- self.scheduler.set_timesteps(num_inference_steps, device=device)
-
- timesteps = self.scheduler.timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -312,6 +603,7 @@ def __call__(
generator,
latents,
)
+
# 7. Prepare reference latent variables
ref_image_latents = self.prepare_ref_latents(
ref_image,
@@ -319,13 +611,21 @@ def __call__(
prompt_embeds.dtype,
device,
generator,
- do_classifier_free_guidance,
+ self.do_classifier_free_guidance,
)
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 9. Modify self attebtion and group norm
+ # 8.1 Create tensor stating which reference controlnets to keep
+ reference_keeps = []
+ for i in range(len(timesteps)):
+ reference_keep = 1.0 - float(
+ i / len(timesteps) < reference_guidance_start or (i + 1) / len(timesteps) > reference_guidance_end
+ )
+ reference_keeps.append(reference_keep)
+
+ # 8.2 Modify self attention and group norm
MODE = "write"
uc_mask = (
torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
@@ -333,6 +633,8 @@ def __call__(
.bool()
)
+ do_classifier_free_guidance = self.do_classifier_free_guidance
+
def hacked_basic_transformer_inner_forward(
self,
hidden_states: torch.Tensor,
@@ -604,7 +906,7 @@ def hacked_CrossAttnUpBlock2D_forward(
return hidden_states
def hacked_UpBlock2D_forward(
- self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
+ self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, *args, **kwargs
):
eps = 1e-6
for i, resnet in enumerate(self.resnets):
@@ -684,7 +986,7 @@ def hacked_UpBlock2D_forward(
module.var_bank = []
module.gn_weight *= 2
- # 10. Prepare added time ids & embeddings
+ # 9. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
@@ -698,62 +1000,101 @@ def hacked_UpBlock2D_forward(
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
- add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
- # 11. Denoising loop
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 10. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 10.1 Apply denoising_end
- if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- - (denoising_end * self.scheduler.config.num_train_timesteps)
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
+ # 11. Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
# ref only part
- noise = randn_tensor(
- ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
- )
- ref_xt = self.scheduler.add_noise(
- ref_image_latents,
- noise,
- t.reshape(
- 1,
- ),
- )
- ref_xt = self.scheduler.scale_model_input(ref_xt, t)
-
- MODE = "write"
-
- self.unet(
- ref_xt,
- t,
- encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=cross_attention_kwargs,
- added_cond_kwargs=added_cond_kwargs,
- return_dict=False,
- )
+ if reference_keeps[i] > 0:
+ noise = randn_tensor(
+ ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
+ )
+ ref_xt = self.scheduler.add_noise(
+ ref_image_latents,
+ noise,
+ t.reshape(
+ 1,
+ ),
+ )
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
+
+ MODE = "write"
+ self.unet(
+ ref_xt,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )
# predict the noise residual
MODE = "read"
@@ -761,22 +1102,44 @@ def hacked_UpBlock2D_forward(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=cross_attention_kwargs,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
- if do_classifier_free_guidance and guidance_rescale > 0.0:
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -785,6 +1148,9 @@ def hacked_UpBlock2D_forward(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
@@ -792,25 +1158,43 @@ def hacked_UpBlock2D_forward(
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ elif latents.dtype != self.vae.dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ self.vae = self.vae.to(latents.dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = self.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
else:
image = latents
- return StableDiffusionXLPipelineOutput(images=image)
- # apply watermark if available
- if self.watermark is not None:
- image = self.watermark.apply_watermark(image)
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
- image = self.image_processor.postprocess(image, output_type=output_type)
+ image = self.image_processor.postprocess(image, output_type=output_type)
- # Offload last model to CPU
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.final_offload_hook.offload()
+ # Offload all models
+ self.maybe_free_model_hooks()
if not return_dict:
return (image,)
diff --git a/examples/community/text_inpainting.py b/examples/community/text_inpainting.py
index c4378ab96f28..d73082b6cf38 100644
--- a/examples/community/text_inpainting.py
+++ b/examples/community/text_inpainting.py
@@ -71,7 +71,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -85,7 +85,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
diff --git a/examples/community/wildcard_stable_diffusion.py b/examples/community/wildcard_stable_diffusion.py
index c866ce2ae904..3c42c54f71f8 100644
--- a/examples/community/wildcard_stable_diffusion.py
+++ b/examples/community/wildcard_stable_diffusion.py
@@ -120,7 +120,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index 611026675daf..2045e7809310 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -73,7 +73,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index 8090926974c4..38fe94ed3fe5 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -66,7 +66,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index fa7e7f1febee..fdb789c21628 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -79,7 +79,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index 12d7db09a361..9a33f71ebac8 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -72,7 +72,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index cc5e6812127e..927e454d2b39 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -78,7 +78,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/controlnet/README_sd3.md b/examples/controlnet/README_sd3.md
index 7a7b4841125f..c95f34e32f38 100644
--- a/examples/controlnet/README_sd3.md
+++ b/examples/controlnet/README_sd3.md
@@ -1,6 +1,6 @@
-# ControlNet training example for Stable Diffusion 3 (SD3)
+# ControlNet training example for Stable Diffusion 3/3.5 (SD3/3.5)
-The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206).
+The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206) and [Stable Diffusion 3.5](https://stability.ai/news/introducing-stable-diffusion-3-5).
## Running locally with PyTorch
@@ -51,9 +51,9 @@ Please download the dataset and unzip it in the directory `fill50k` in the `exam
## Training
-First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium). We will use it as a base model for the ControlNet training.
+First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or the SD3.5 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). We will use it as a base model for the ControlNet training.
> [!NOTE]
-> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or [Stable Diffusion 3.5 Large Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
```bash
huggingface-cli login
@@ -90,6 +90,8 @@ accelerate launch train_controlnet_sd3.py \
--gradient_accumulation_steps=4
```
+To train a ControlNet model for Stable Diffusion 3.5, replace the `MODEL_DIR` with `stabilityai/stable-diffusion-3.5-medium`.
+
To better track our training experiments, we're using flags `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
Our experiments were conducted on a single 40GB A100 GPU.
@@ -124,6 +126,8 @@ image = pipe(
image.save("./output.png")
```
+Similarly, for SD3.5, replace the `base_model_path` with `stabilityai/stable-diffusion-3.5-medium` and controlnet_path `DavyMorgan/sd35-controlnet-out'.
+
## Notes
### GPU usage
@@ -135,6 +139,8 @@ Make sure to use the right GPU when configuring the [accelerator](https://huggin
## Example results
+### SD3
+
#### After 500 steps with batch size 8
| | |
@@ -150,3 +156,20 @@ Make sure to use the right GPU when configuring the [accelerator](https://huggin
|| pale golden rod circle with old lace background |
 |  |
+### SD3.5
+
+#### After 500 steps with batch size 8
+
+| | |
+|-------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------:|
+|| pale golden rod circle with old lace background |
+  |  |
+
+
+#### After 3000 steps with batch size 8:
+
+| | |
+|-------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------:|
+|| pale golden rod circle with old lace background |
+  |  |
+
diff --git a/examples/controlnet/test_controlnet.py b/examples/controlnet/test_controlnet.py
index 3c508f80f1a4..d595a1a312b0 100644
--- a/examples/controlnet/test_controlnet.py
+++ b/examples/controlnet/test_controlnet.py
@@ -138,6 +138,27 @@ def test_controlnet_sd3(self):
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+class ControlNetSD35(ExamplesTestsAccelerate):
+ def test_controlnet_sd3(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet_sd3.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-sd35-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd35
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+
+
class ControlNetflux(ExamplesTestsAccelerate):
def test_controlnet_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index eaeb697c64c0..aa235ad65bfe 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -60,7 +60,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -152,9 +152,7 @@ def log_validation(
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
- formatted_images = []
-
- formatted_images.append(np.asarray(validation_image))
+ formatted_images = [np.asarray(validation_image)]
for image in images:
formatted_images.append(np.asarray(image))
@@ -571,9 +569,6 @@ def parse_args(input_args=None):
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
- if args.dataset_name is not None and args.train_data_dir is not None:
- raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
-
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
@@ -615,6 +610,7 @@ def make_train_dataset(args, tokenizer, accelerator):
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
)
else:
if args.train_data_dir is not None:
@@ -1145,7 +1141,7 @@ def load_model_hook(models, input_dir):
if global_step >= args.max_train_steps:
break
- # Create the pipeline using using the trained modules and save it.
+ # Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
controlnet = unwrap_model(controlnet)
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 1aa9e881fca5..50af4ff8c39d 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -60,7 +60,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py
index 5969218f3c3e..a41615c7b546 100644
--- a/examples/controlnet/train_controlnet_flux.py
+++ b/examples/controlnet/train_controlnet_flux.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -65,7 +65,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -152,6 +152,7 @@ def log_validation(
guidance_scale=3.5,
generator=generator,
).images[0]
+ image = image.resize((args.resolution, args.resolution))
images.append(image)
image_logs.append(
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
@@ -165,9 +166,7 @@ def log_validation(
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
- formatted_images = []
-
- formatted_images.append(np.asarray(validation_image))
+ formatted_images = [np.asarray(validation_image)]
for image in images:
formatted_images.append(np.asarray(image))
@@ -1256,8 +1255,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(
batch_size=pixel_latents_tmp.shape[0],
- height=pixel_latents_tmp.shape[2],
- width=pixel_latents_tmp.shape[3],
+ height=pixel_latents_tmp.shape[2] // 2,
+ width=pixel_latents_tmp.shape[3] // 2,
device=pixel_values.device,
dtype=pixel_values.dtype,
)
diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py
index dbe41578dc09..ffe460d72de8 100644
--- a/examples/controlnet/train_controlnet_sd3.py
+++ b/examples/controlnet/train_controlnet_sd3.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -59,7 +59,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.30.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -263,6 +263,12 @@ def parse_args(input_args=None):
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
" If not specified controlnet weights are initialized from unet.",
)
+ parser.add_argument(
+ "--num_extra_conditioning_channels",
+ type=int,
+ default=0,
+ help="Number of extra conditioning channels for controlnet.",
+ )
parser.add_argument(
"--revision",
type=str,
@@ -539,6 +545,9 @@ def parse_args(input_args=None):
default=77,
help="Maximum sequence length to use with with the T5 text encoder",
)
+ parser.add_argument(
+ "--dataset_preprocess_batch_size", type=int, default=1000, help="Batch size for preprocessing dataset."
+ )
parser.add_argument(
"--validation_prompt",
type=str,
@@ -986,7 +995,9 @@ def main(args):
controlnet = SD3ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
else:
logger.info("Initializing controlnet weights from transformer")
- controlnet = SD3ControlNetModel.from_transformer(transformer)
+ controlnet = SD3ControlNetModel.from_transformer(
+ transformer, num_extra_conditioning_channels=args.num_extra_conditioning_channels
+ )
transformer.requires_grad_(False)
vae.requires_grad_(False)
@@ -1123,7 +1134,12 @@ def compute_text_embeddings(batch, text_encoders, tokenizers):
# fingerprint used by the cache for the other processes to load the result
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
new_fingerprint = Hasher.hash(args)
- train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
+ train_dataset = train_dataset.map(
+ compute_embeddings_fn,
+ batched=True,
+ batch_size=args.dataset_preprocess_batch_size,
+ new_fingerprint=new_fingerprint,
+ )
del text_encoder_one, text_encoder_two, text_encoder_three
del tokenizer_one, tokenizer_two, tokenizer_three
@@ -1267,8 +1283,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# Get the text embedding for conditioning
- prompt_embeds = batch["prompt_embeds"]
- pooled_prompt_embeds = batch["pooled_prompt_embeds"]
+ prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)
# controlnet(s) inference
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index ae627bb3a04c..17f313752989 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -61,7 +61,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -157,9 +157,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
- formatted_images = []
-
- formatted_images.append(np.asarray(validation_image))
+ formatted_images = [np.asarray(validation_image)]
for image in images:
formatted_images.append(np.asarray(image))
@@ -598,9 +596,6 @@ def parse_args(input_args=None):
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
- if args.dataset_name is not None and args.train_data_dir is not None:
- raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
-
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
@@ -642,6 +637,7 @@ def get_train_dataset(args, accelerator):
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
)
else:
if args.train_data_dir is not None:
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index e498ca98b1c7..ea1449f9f382 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -63,7 +63,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -1334,7 +1334,9 @@ def main(args):
# run inference
if args.validation_prompt and args.num_validation_images > 0:
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = (
+ torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+ )
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
for _ in range(args.num_validation_images)
diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index f97a4d0cd0f4..eed0575c322d 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -742,3 +742,29 @@ accelerate launch train_dreambooth.py \
## Stable Diffusion XL
We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
+
+## Dataset
+
+We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index), you can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own.
+
+The quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
+
+We need to create a file `metadata.jsonl` in the directory with our images:
+
+```
+{"file_name": "01.jpg", "prompt": "prompt 01"}
+{"file_name": "02.jpg", "prompt": "prompt 02"}
+```
+
+If we have a directory with image-text pairs e.g. `01.jpg` and `01.txt` then `convert_to_imagefolder.py` can create `metadata.jsonl`.
+
+```sh
+python convert_to_imagefolder.py --path my_dataset/
+```
+
+We use `--dataset_name` and `--caption_column` with training scripts.
+
+```
+--dataset_name=my_dataset/
+--caption_column=prompt
+```
diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md
index 69dfd241395b..c0802246e1f2 100644
--- a/examples/dreambooth/README_flux.md
+++ b/examples/dreambooth/README_flux.md
@@ -118,7 +118,7 @@ accelerate launch train_dreambooth_flux.py \
To better track our training experiments, we're using the following flags in the command above:
-* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
@@ -170,6 +170,21 @@ accelerate launch train_dreambooth_lora_flux.py \
--push_to_hub
```
+### Target Modules
+When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
+More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
+applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
+the exact modules for LoRA training. Here are some examples of target modules you can provide:
+- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
+- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
+- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
+> [!NOTE]
+> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
+> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
+> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
+> [!NOTE]
+> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
+
### Text Encoder Training
Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
diff --git a/examples/dreambooth/README_lumina2.md b/examples/dreambooth/README_lumina2.md
new file mode 100644
index 000000000000..e466ec5a68e7
--- /dev/null
+++ b/examples/dreambooth/README_lumina2.md
@@ -0,0 +1,127 @@
+# DreamBooth training example for Lumina2
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
+
+The `train_dreambooth_lora_lumina2.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
+
+
+This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd in the `examples/dreambooth` folder and run
+```bash
+pip install -r requirements_sana.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
+
+
+### Dog toy example
+
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
+
+Now, we can launch training using:
+
+```bash
+export MODEL_NAME="Alpha-VLLM/Lumina-Image-2.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-lumina2-lora"
+
+accelerate launch train_dreambooth_lora_lumina2.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="bf16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --use_8bit_adam \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+For using `push_to_hub`, make you're logged into your Hugging Face account:
+
+```bash
+huggingface-cli login
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+## Notes
+
+Additionally, we welcome you to explore the following CLI arguments:
+
+* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
+* `--system_prompt`: A custom system prompt to provide additional personality to the model.
+* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
+
+
+We provide several options for optimizing memory optimization:
+
+* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
+* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
+* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
+
+Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2) of the `LuminaPipeline` to know more about the model.
diff --git a/examples/dreambooth/README_sana.md b/examples/dreambooth/README_sana.md
new file mode 100644
index 000000000000..d82529c64de8
--- /dev/null
+++ b/examples/dreambooth/README_sana.md
@@ -0,0 +1,127 @@
+# DreamBooth training example for SANA
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
+
+The `train_dreambooth_lora_sana.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [SANA](https://arxiv.org/abs/2410.10629).
+
+
+This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd in the `examples/dreambooth` folder and run
+```bash
+pip install -r requirements_sana.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
+
+
+### Dog toy example
+
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
+
+Now, we can launch training using:
+
+```bash
+export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="trained-sana-lora"
+
+accelerate launch train_dreambooth_lora_sana.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="bf16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --use_8bit_adam \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+For using `push_to_hub`, make you're logged into your Hugging Face account:
+
+```bash
+huggingface-cli login
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+## Notes
+
+Additionally, we welcome you to explore the following CLI arguments:
+
+* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
+* `--complex_human_instruction`: Instructions for complex human attention as shown in [here](https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55).
+* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
+
+
+We provide several options for optimizing memory optimization:
+
+* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
+* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
+* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
+
+Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
diff --git a/examples/dreambooth/README_sd3.md b/examples/dreambooth/README_sd3.md
index a340be350db8..2ac7bf7101d8 100644
--- a/examples/dreambooth/README_sd3.md
+++ b/examples/dreambooth/README_sd3.md
@@ -105,7 +105,7 @@ accelerate launch train_dreambooth_sd3.py \
To better track our training experiments, we're using the following flags in the command above:
-* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
--push_to_hub
```
+### Targeting Specific Blocks & Layers
+As image generation models get bigger & more powerful, more fine-tuners come to find that training only part of the
+transformer blocks (sometimes as little as two) can be enough to get great results.
+In some cases, it can be even better to maintain some of the blocks/layers frozen.
+
+For **SD3.5-Large** specifically, you may find this information useful (taken from: [Stable Diffusion 3.5 Large Fine-tuning Tutorial](https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6#12461cdcd19680788a23c650dab26b93):
+> [!NOTE]
+> A commonly believed heuristic that we verified once again during the construction of the SD3.5 family of models is that later/higher layers (i.e. `30 - 37`)* impact tertiary details more heavily. Conversely, earlier layers (i.e. `12 - 24` )* influence the overall composition/primary form more.
+> So, freezing other layers/targeting specific layers is a viable approach.
+> `*`These suggested layers are speculative and not 100% guaranteed. The tips here are more or less a general idea for next steps.
+> **Photorealism**
+> In preliminary testing, we observed that freezing the last few layers of the architecture significantly improved model training when using a photorealistic dataset, preventing detail degradation introduced by small dataset from happening.
+> **Anatomy preservation**
+> To dampen any possible degradation of anatomy, training only the attention layers and **not** the adaptive linear layers could help. For reference, below is one of the transformer blocks.
+
+
+We've added `--lora_layers` and `--lora_blocks` to make LoRA training modules configurable.
+- with `--lora_blocks` you can specify the block numbers for training. E.g. passing -
+```diff
+--lora_blocks "12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37"
+```
+will trigger LoRA training of transformer blocks 12-24 and 30-37. By default, all blocks are trained.
+- with `--lora_layers` you can specify the types of layers you wish to train.
+By default, the trained layers are -
+`attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,attn.to_k,attn.to_out.0,attn.to_q,attn.to_v`
+If you wish to have a leaner LoRA / train more blocks over layers you could pass -
+```diff
++ --lora_layers attn.to_k,attn.to_q,attn.to_v,attn.to_out.0
+```
+This will reduce LoRA size by roughly 50% for the same rank compared to the default.
+However, if you're after compact LoRAs, it's our impression that maintaining the default setting for `--lora_layers` and
+freezing some of the early & blocks is usually better.
+
+
### Text Encoder Training
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
diff --git a/examples/dreambooth/README_sdxl.md b/examples/dreambooth/README_sdxl.md
index 7a42bf8fffd8..565ff9a5dd33 100644
--- a/examples/dreambooth/README_sdxl.md
+++ b/examples/dreambooth/README_sdxl.md
@@ -99,7 +99,7 @@ accelerate launch train_dreambooth_lora_sdxl.py \
To better track our training experiments, we're using the following flags in the command above:
-* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
Our experiments were conducted on a single 40GB A100 GPU.
diff --git a/examples/dreambooth/convert_to_imagefolder.py b/examples/dreambooth/convert_to_imagefolder.py
new file mode 100644
index 000000000000..333080077428
--- /dev/null
+++ b/examples/dreambooth/convert_to_imagefolder.py
@@ -0,0 +1,32 @@
+import argparse
+import json
+import pathlib
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--path",
+ type=str,
+ required=True,
+ help="Path to folder with image-text pairs.",
+)
+parser.add_argument("--caption_column", type=str, default="prompt", help="Name of caption column.")
+args = parser.parse_args()
+
+path = pathlib.Path(args.path)
+if not path.exists():
+ raise RuntimeError(f"`--path` '{args.path}' does not exist.")
+
+all_files = list(path.glob("*"))
+captions = list(path.glob("*.txt"))
+images = set(all_files) - set(captions)
+images = {image.stem: image for image in images}
+caption_image = {caption: images.get(caption.stem) for caption in captions if images.get(caption.stem)}
+
+metadata = path.joinpath("metadata.jsonl")
+
+with metadata.open("w", encoding="utf-8") as f:
+ for caption, image in caption_image.items():
+ caption_text = caption.read_text(encoding="utf-8")
+ json.dump({"file_name": image.name, args.caption_column: caption_text}, f)
+ f.write("\n")
diff --git a/examples/dreambooth/requirements_sana.txt b/examples/dreambooth/requirements_sana.txt
new file mode 100644
index 000000000000..04b4bd6c29c0
--- /dev/null
+++ b/examples/dreambooth/requirements_sana.txt
@@ -0,0 +1,8 @@
+accelerate>=1.0.0
+torchvision
+transformers>=4.47.0
+ftfy
+tensorboard
+Jinja2
+peft>=0.14.0
+sentencepiece
\ No newline at end of file
diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py
index d197c8187b87..a76825e29448 100644
--- a/examples/dreambooth/test_dreambooth_lora_flux.py
+++ b/examples/dreambooth/test_dreambooth_lora_flux.py
@@ -37,6 +37,7 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_flux.py"
+ transformer_layer_type = "single_transformer_blocks.0.attn.to_k"
def test_dreambooth_lora_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
@@ -136,6 +137,43 @@ def test_dreambooth_lora_latent_caching(self):
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
+ def test_dreambooth_lora_layers(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lora_layers {self.transformer_layer_type}
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names. In this test, we only params of
+ # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
+ starts_with_transformer = all(
+ key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys()
+ )
+ self.assertTrue(starts_with_transformer)
+
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
diff --git a/examples/dreambooth/test_dreambooth_lora_lumina2.py b/examples/dreambooth/test_dreambooth_lora_lumina2.py
new file mode 100644
index 000000000000..1b729a0ff52e
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_lumina2.py
@@ -0,0 +1,206 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRAlumina2(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-lumina2-pipe"
+ script_path = "examples/dreambooth/train_dreambooth_lora_lumina2.py"
+ transformer_layer_type = "layers.0.attn.to_k"
+
+ def test_dreambooth_lora_lumina2(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --resolution 32
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --max_sequence_length 16
+ """.split()
+
+ test_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_latent_caching(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --resolution 32
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --max_sequence_length 16
+ """.split()
+
+ test_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_layers(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --resolution 32
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lora_layers {self.transformer_layer_type}
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --max_sequence_length 16
+ """.split()
+
+ test_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names. In this test, we only params of
+ # `self.transformer_layer_type` should be in the state dict.
+ starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_lumina2_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --resolution=32
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ --max_sequence_length 16
+ """.split()
+
+ test_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_lumina2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --resolution=32
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ --max_sequence_length 166
+ """.split()
+
+ test_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --resolution=32
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ --max_sequence_length 16
+ """.split()
+
+ resume_run_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/examples/dreambooth/test_dreambooth_lora_sana.py b/examples/dreambooth/test_dreambooth_lora_sana.py
new file mode 100644
index 000000000000..dfceb09a9736
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_sana.py
@@ -0,0 +1,206 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRASANA(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-sana-pipe"
+ script_path = "examples/dreambooth/train_dreambooth_lora_sana.py"
+ transformer_layer_type = "transformer_blocks.0.attn1.to_k"
+
+ def test_dreambooth_lora_sana(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --resolution 32
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --max_sequence_length 16
+ """.split()
+
+ test_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_latent_caching(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --resolution 32
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --max_sequence_length 16
+ """.split()
+
+ test_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_layers(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --resolution 32
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lora_layers {self.transformer_layer_type}
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --max_sequence_length 16
+ """.split()
+
+ test_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names. In this test, we only params of
+ # `self.transformer_layer_type` should be in the state dict.
+ starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --resolution=32
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ --max_sequence_length 16
+ """.split()
+
+ test_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --resolution=32
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ --max_sequence_length 166
+ """.split()
+
+ test_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --resolution=32
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ --max_sequence_length 16
+ """.split()
+
+ resume_run_args.extend(["--instance_prompt", ""])
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/examples/dreambooth/test_dreambooth_lora_sd3.py b/examples/dreambooth/test_dreambooth_lora_sd3.py
index ec323be4143e..5d6c8bb9938a 100644
--- a/examples/dreambooth/test_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/test_dreambooth_lora_sd3.py
@@ -38,6 +38,9 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"
+ transformer_block_idx = 0
+ layer_type = "attn.to_k"
+
def test_dreambooth_lora_sd3(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
@@ -136,6 +139,74 @@ def test_dreambooth_lora_latent_caching(self):
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
+ def test_dreambooth_lora_block(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --lora_blocks {self.transformer_block_idx}
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ # In this test, only params of transformer block 0 should be in the state dict
+ starts_with_transformer = all(
+ key.startswith("transformer.transformer_blocks.0") for key in lora_state_dict.keys()
+ )
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_layer(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --lora_layers {self.layer_type}
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # In this test, only transformer params of attention layers `attn.to_k` should be in the state dict
+ starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 5099107118e4..b863f5641233 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,7 +63,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -1300,16 +1300,17 @@ def compute_text_embeddings(prompt):
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
- base_weight = (
- torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
- )
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
- mse_loss_weights = base_weight + 1
+ divisor = snr + 1
else:
- # Epsilon and sample both use the same loss weights.
- mse_loss_weights = base_weight
+ divisor = snr
+
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / divisor
+ )
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py
index 29fd5e78535d..f38cb1098358 100644
--- a/examples/dreambooth/train_dreambooth_flax.py
+++ b/examples/dreambooth/train_dreambooth_flax.py
@@ -35,7 +35,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 8e0f4e09a461..66f533e52a8a 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -57,6 +57,7 @@
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -64,10 +65,16 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
+if is_torch_npu_available():
+ import torch_npu
+
+ torch.npu.config.allow_internal_format = False
+ torch.npu.set_compile_mode(jit_compile=False)
+
def save_model_card(
repo_id: str,
@@ -161,11 +168,11 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
- pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
@@ -189,6 +196,8 @@ def log_validation(
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
+ elif is_torch_npu_available():
+ torch_npu.npu.empty_cache()
return images
@@ -1035,7 +1044,9 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
- has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ has_supported_fp16_accelerator = (
+ torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available()
+ )
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
@@ -1073,6 +1084,8 @@ def main(args):
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
+ elif is_torch_npu_available():
+ torch_npu.npu.empty_cache()
# Handle the repository creation
if accelerator.is_main_process:
@@ -1226,10 +1239,7 @@ def load_model_hook(models, input_dir):
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
- params_to_optimize = [
- transformer_parameters_with_lr,
- text_parameters_one_with_lr,
- ]
+ params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
else:
params_to_optimize = [transformer_parameters_with_lr]
@@ -1288,11 +1298,9 @@ def load_model_hook(models, input_dir):
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
- params_to_optimize[2]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
- lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -1359,6 +1367,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
+ elif is_torch_npu_available():
+ torch_npu.npu.empty_cache()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1540,12 +1550,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
model_input = model_input.to(dtype=weight_dtype)
- vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
- model_input.shape[2],
- model_input.shape[3],
+ model_input.shape[2] // 2,
+ model_input.shape[3] // 2,
accelerator.device,
weight_dtype,
)
@@ -1580,7 +1590,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)
# handle guidance
- if transformer.config.guidance_embeds:
+ if accelerator.unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -1601,8 +1611,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
model_pred = FluxPipeline._unpack_latents(
model_pred,
- height=int(model_input.shape[2] * vae_scale_factor / 2),
- width=int(model_input.shape[3] * vae_scale_factor / 2),
+ height=model_input.shape[2] * vae_scale_factor,
+ width=model_input.shape[3] * vae_scale_factor,
vae_scale_factor=vae_scale_factor,
)
@@ -1694,6 +1704,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# create pipeline
if not args.train_text_encoder:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ text_encoder_one.to(weight_dtype)
+ text_encoder_two.to(weight_dtype)
else: # even when training the text encoder we're only training text encoder one
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path,
@@ -1704,9 +1716,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
- text_encoder=accelerator.unwrap_model(text_encoder_one),
- text_encoder_2=accelerator.unwrap_model(text_encoder_two),
- transformer=accelerator.unwrap_model(transformer),
+ text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
+ transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -1722,9 +1734,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two
- torch.cuda.empty_cache()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ elif is_torch_npu_available():
+ torch_npu.npu.empty_cache()
gc.collect()
+ images = None
+ del pipeline
+
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
@@ -1783,6 +1801,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
ignore_patterns=["step_*", "epoch_*"],
)
+ images = None
+ del pipeline
+
accelerator.end_training()
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 5d7d697bb21d..9584e7762dbd 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -54,7 +54,11 @@
)
from diffusers.loaders import StableDiffusionLoraLoaderMixin
from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.training_utils import (
+ _set_state_dict_into_text_encoder,
+ cast_training_params,
+ free_memory,
+)
from diffusers.utils import (
check_min_version,
convert_state_dict_to_diffusers,
@@ -70,7 +74,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -146,19 +150,19 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
if args.validation_images is None:
images = []
for _ in range(args.num_validation_images):
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast(accelerator.device.type):
image = pipeline(**pipeline_args, generator=generator).images[0]
images.append(image)
else:
images = []
for image in args.validation_images:
image = Image.open(image)
- with torch.cuda.amp.autocast():
+ with torch.amp.autocast(accelerator.device.type):
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)
@@ -177,7 +181,7 @@ def log_validation(
)
del pipeline
- torch.cuda.empty_cache()
+ free_memory()
return images
@@ -793,7 +797,7 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
- torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
elif args.prior_generation_precision == "fp16":
@@ -829,8 +833,7 @@ def main(args):
image.save(image_filename)
del pipeline
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
+ free_memory()
# Handle the repository creation
if accelerator.is_main_process:
@@ -1085,7 +1088,7 @@ def compute_text_embeddings(prompt):
tokenizer = None
gc.collect()
- torch.cuda.empty_cache()
+ free_memory()
else:
pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None
@@ -1116,17 +1119,22 @@ def compute_text_embeddings(prompt):
)
# Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1143,8 +1151,15 @@ def compute_text_embeddings(prompt):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
+ if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 11cba745cc4a..dda3300d65cc 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -72,7 +72,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -177,11 +177,11 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
- pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
@@ -554,6 +554,15 @@ def parse_args(input_args=None):
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+ 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
+ ),
+ )
+
parser.add_argument(
"--adam_epsilon",
type=float,
@@ -1186,12 +1195,30 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.gradient_checkpointing_enable()
- # now we will add new LoRA weights to the attention layers
+ if args.lora_layers is not None:
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+ target_modules = [
+ "attn.to_k",
+ "attn.to_q",
+ "attn.to_v",
+ "attn.to_out.0",
+ "attn.add_k_proj",
+ "attn.add_q_proj",
+ "attn.add_v_proj",
+ "attn.to_add_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "ff_context.net.0.proj",
+ "ff_context.net.2",
+ ]
+
+ # now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)
if args.train_text_encoder:
@@ -1308,10 +1335,7 @@ def load_model_hook(models, input_dir):
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
- params_to_optimize = [
- transformer_parameters_with_lr,
- text_parameters_one_with_lr,
- ]
+ params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
else:
params_to_optimize = [transformer_parameters_with_lr]
@@ -1367,14 +1391,12 @@ def load_model_hook(models, input_dir):
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
- # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+ # changes the learning rate of text_encoder_parameters_one to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
- params_to_optimize[2]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
- lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -1626,11 +1648,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
prompt=prompts,
)
else:
+ elems_to_repeat = len(prompts)
if args.train_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=[None, None],
- text_input_ids_list=[tokens_one, tokens_two],
+ text_input_ids_list=[
+ tokens_one.repeat(elems_to_repeat, 1),
+ tokens_two.repeat(elems_to_repeat, 1),
+ ],
max_sequence_length=args.max_sequence_length,
device=accelerator.device,
prompt=args.instance_prompt,
@@ -1645,12 +1671,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
- vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+ vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
- model_input.shape[2],
- model_input.shape[3],
+ model_input.shape[2] // 2,
+ model_input.shape[3] // 2,
accelerator.device,
weight_dtype,
)
@@ -1684,7 +1710,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)
# handle guidance
- if transformer.config.guidance_embeds:
+ if accelerator.unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -1704,8 +1730,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)[0]
model_pred = FluxPipeline._unpack_latents(
model_pred,
- height=int(model_input.shape[2] * vae_scale_factor / 2),
- width=int(model_input.shape[3] * vae_scale_factor / 2),
+ height=model_input.shape[2] * vae_scale_factor,
+ width=model_input.shape[3] * vae_scale_factor,
vae_scale_factor=vae_scale_factor,
)
@@ -1797,6 +1823,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# create pipeline
if not args.train_text_encoder:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ text_encoder_one.to(weight_dtype)
+ text_encoder_two.to(weight_dtype)
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
@@ -1820,6 +1848,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
del text_encoder_one, text_encoder_two
free_memory()
+ images = None
+ del pipeline
+
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
@@ -1884,6 +1915,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
ignore_patterns=["step_*", "epoch_*"],
)
+ images = None
+ del pipeline
+
accelerator.end_training()
diff --git a/examples/dreambooth/train_dreambooth_lora_lumina2.py b/examples/dreambooth/train_dreambooth_lora_lumina2.py
new file mode 100644
index 000000000000..a8bf4e1cdc61
--- /dev/null
+++ b/examples/dreambooth/train_dreambooth_lora_lumina2.py
@@ -0,0 +1,1563 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import copy
+import itertools
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, Gemma2Model
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ Lumina2Text2ImgPipeline,
+ Lumina2Transformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ free_memory,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.33.0.dev0")
+
+logger = get_logger(__name__)
+
+if is_torch_npu_available():
+ torch.npu.config.allow_internal_format = False
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ instance_prompt=None,
+ system_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# Lumina2 DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Lumina2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_lumina2.md).
+
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+The following `system_prompt` was also used used during training (ignore if `None`): {system_prompt}.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+TODO
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="apache-2.0",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "lumina2",
+ "lumina2-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+ autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+
+ with autocast_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {pipeline_args['prompt']}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=256,
+ help="Maximum sequence length to use with with the Gemma2 model",
+ )
+ parser.add_argument(
+ "--system_prompt",
+ type=str,
+ default=None,
+ help="System prompt to use during inference to give the Gemma2 model certain characteristics.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--final_validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lumina2-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+ 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only'
+ ),
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+ parser.add_argument(
+ "--offload",
+ action="store_true",
+ help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in self.instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # custom prompts were provided, but length does not match size of image dataset
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ pipeline = Lumina2Text2ImgPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ free_memory()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ text_encoder = Gemma2Model.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = Lumina2Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ # keep VAE in FP32 to ensure numerical stability.
+ vae.to(dtype=torch.float32)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ # because Gemma2 is particularly suited for bfloat16.
+ text_encoder.to(dtype=torch.bfloat16)
+
+ # Initialize a text encoding pipeline and keep it to CPU for now.
+ text_encoding_pipeline = Lumina2Text2ImgPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=None,
+ transformer=None,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ if args.lora_layers is not None:
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+ target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+
+ # now we will add new LoRA weights the transformer layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ Lumina2Text2ImgPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict = Lumina2Text2ImgPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ def compute_text_embeddings(prompt, text_encoding_pipeline):
+ text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
+ with torch.no_grad():
+ prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
+ prompt,
+ max_sequence_length=args.max_sequence_length,
+ system_prompt=args.system_prompt,
+ )
+ if args.offload:
+ text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ prompt_embeds = prompt_embeds.to(transformer.dtype)
+ return prompt_embeds, prompt_attention_mask
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_prompt_attention_mask = compute_text_embeddings(
+ args.instance_prompt, text_encoding_pipeline
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ class_prompt_hidden_states, class_prompt_attention_mask = compute_text_embeddings(
+ args.class_prompt, text_encoding_pipeline
+ )
+
+ # Clear the memory here
+ if not train_dataset.custom_instance_prompts:
+ del text_encoder, tokenizer
+ free_memory()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+ if not train_dataset.custom_instance_prompts:
+ prompt_embeds = instance_prompt_hidden_states
+ prompt_attention_mask = instance_prompt_attention_mask
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0)
+
+ vae_config_scaling_factor = vae.config.scaling_factor
+ vae_config_shift_factor = vae.config.shift_factor
+ if args.cache_latents:
+ latents_cache = []
+ vae = vae.to(accelerator.device)
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=vae.dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+ if args.validation_prompt is None:
+ del vae
+ free_memory()
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-lumina2-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the mos recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ prompts = batch["prompts"]
+
+ with accelerator.accumulate(models_to_accumulate):
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline)
+
+ # Convert images to latent space
+ if args.cache_latents:
+ model_input = latents_cache[step].sample()
+ else:
+ vae = vae.to(accelerator.device)
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ if args.offload:
+ vae = vae.to("cpu")
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ # Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `model_input`
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * noise + sigmas * model_input
+
+ # Predict the noise residual
+ # scale the timesteps (reversal not needed as we used a reverse lerp above already)
+ timesteps = timesteps / noise_scheduler.config.num_train_timesteps
+ model_pred = transformer(
+ hidden_states=noisy_model_input,
+ encoder_hidden_states=prompt_embeds.repeat(len(prompts), 1, 1)
+ if not train_dataset.custom_instance_prompts
+ else prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask.repeat(len(prompts), 1)
+ if not train_dataset.custom_instance_prompts
+ else prompt_attention_mask,
+ timestep=timesteps,
+ return_dict=False,
+ )[0]
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss (reversed)
+ target = model_input - noise
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = Lumina2Text2ImgPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt, "system_prompt": args.system_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ )
+ free_memory()
+
+ images = None
+ del pipeline
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ Lumina2Text2ImgPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = Lumina2Text2ImgPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt):
+ prompt_to_use = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
+ args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
+ pipeline_args = {"prompt": prompt_to_use, "system_prompt": args.system_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ validation_prpmpt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=args.instance_prompt,
+ system_prompt=args.system_prompt,
+ validation_prompt=validation_prpmpt,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ images = None
+ del pipeline
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
new file mode 100644
index 000000000000..674cb0d1ad1e
--- /dev/null
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -0,0 +1,1569 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import copy
+import itertools
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, Gemma2Model
+
+import diffusers
+from diffusers import (
+ AutoencoderDC,
+ FlowMatchEulerDiscreteScheduler,
+ SanaPipeline,
+ SanaTransformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ free_memory,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.33.0.dev0")
+
+logger = get_logger(__name__)
+
+if is_torch_npu_available():
+ torch.npu.config.allow_internal_format = False
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# Sana DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Sana diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md).
+
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+TODO
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+TODO
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "sana",
+ "sana-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ if args.enable_vae_tiling:
+ pipeline.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_stride_width=1024)
+
+ pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=300,
+ help="Maximum sequence length to use with with the Gemma model",
+ )
+ parser.add_argument(
+ "--complex_human_instruction",
+ type=str,
+ default=None,
+ help="Instructions for complex human attention: https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sana-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+ 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only'
+ ),
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+ parser.add_argument(
+ "--offload",
+ action="store_true",
+ help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
+ parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in self.instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # custom prompts were provided, but length does not match size of image dataset
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ pipeline = SanaPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch.float32,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)
+ pipeline.transformer = pipeline.transformer.to(torch.float16)
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ free_memory()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ text_encoder = Gemma2Model.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderDC.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = SanaTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ # VAE should always be kept in fp32 for SANA (?)
+ vae.to(dtype=torch.float32)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ # because Gemma2 is particularly suited for bfloat16.
+ text_encoder.to(dtype=torch.bfloat16)
+
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ for block in transformer.transformer_blocks:
+ block.attn2.set_use_npu_flash_attention(True)
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
+
+ # Initialize a text encoding pipeline and keep it to CPU for now.
+ text_encoding_pipeline = SanaPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=None,
+ transformer=None,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ if args.lora_layers is not None:
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+ target_modules = ["to_k", "to_q", "to_v"]
+
+ # now we will add new LoRA weights the transformer layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ SanaPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict = SanaPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ def compute_text_embeddings(prompt, text_encoding_pipeline):
+ text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
+ with torch.no_grad():
+ prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
+ prompt,
+ max_sequence_length=args.max_sequence_length,
+ complex_human_instruction=args.complex_human_instruction,
+ )
+ if args.offload:
+ text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ prompt_embeds = prompt_embeds.to(transformer.dtype)
+ return prompt_embeds, prompt_attention_mask
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_prompt_attention_mask = compute_text_embeddings(
+ args.instance_prompt, text_encoding_pipeline
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ class_prompt_hidden_states, class_prompt_attention_mask = compute_text_embeddings(
+ args.class_prompt, text_encoding_pipeline
+ )
+
+ # Clear the memory here
+ if not train_dataset.custom_instance_prompts:
+ del text_encoder, tokenizer
+ free_memory()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+ if not train_dataset.custom_instance_prompts:
+ prompt_embeds = instance_prompt_hidden_states
+ prompt_attention_mask = instance_prompt_attention_mask
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0)
+
+ vae_config_scaling_factor = vae.config.scaling_factor
+ if args.cache_latents:
+ latents_cache = []
+ vae = vae.to(accelerator.device)
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=vae.dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent)
+
+ if args.validation_prompt is None:
+ del vae
+ free_memory()
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-sana-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the mos recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ with accelerator.accumulate(models_to_accumulate):
+ prompts = batch["prompts"]
+
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline)
+
+ # Convert images to latent space
+ if args.cache_latents:
+ model_input = latents_cache[step]
+ else:
+ vae = vae.to(accelerator.device)
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent
+ if args.offload:
+ vae = vae.to("cpu")
+ model_input = model_input * vae_config_scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=noisy_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timesteps,
+ return_dict=False,
+ )[0]
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
+ target = noise - model_input
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = SanaPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=torch.float32,
+ )
+ pipeline_args = {
+ "prompt": args.validation_prompt,
+ "complex_human_instruction": args.complex_human_instruction,
+ }
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ )
+ free_memory()
+
+ images = None
+ del pipeline
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ SanaPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = SanaPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=torch.float32,
+ )
+ pipeline.transformer = pipeline.transformer.to(torch.float16)
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {
+ "prompt": args.validation_prompt,
+ "complex_human_instruction": args.complex_human_instruction,
+ }
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ images = None
+ del pipeline
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 8d0b6853eeec..4a08daaf61f7 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@
import torch
import torch.utils.checkpoint
import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
@@ -72,7 +72,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -86,6 +86,15 @@ def save_model_card(
validation_prompt=None,
repo_folder=None,
):
+ if "large" in base_model:
+ model_variant = "SD3.5-Large"
+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
+ variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
+ else:
+ model_variant = "SD3"
+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
+ variant_tags = ["sd3", "sd3-diffusers"]
+
widget_dict = []
if images is not None:
for i, image in enumerate(images):
@@ -95,7 +104,7 @@ def save_model_card(
)
model_description = f"""
-# SD3 DreamBooth LoRA - {repo_id}
+# {model_variant} DreamBooth LoRA - {repo_id}
@@ -120,7 +129,7 @@ def save_model_card(
```py
from diffusers import AutoPipelineForText2Image
import torch
-pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda')
+pipeline = AutoPipelineForText2Image.from_pretrained({base_model}, torch_dtype=torch.float16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
```
@@ -135,7 +144,7 @@ def save_model_card(
## License
-Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
+Please adhere to the licensing terms as described [here]({license_url}).
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
@@ -151,11 +160,11 @@ def save_model_card(
"diffusers-training",
"diffusers",
"lora",
- "sd3",
- "sd3-diffusers",
"template:sd-lora",
]
+ tags += variant_tags
+
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
@@ -190,7 +199,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
@@ -562,6 +571,25 @@ def parse_args(input_args=None):
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+ "The transformer block layers to apply LoRA training on. Please specify the layers in a comma seperated string."
+ "For examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md"
+ ),
+ )
+ parser.add_argument(
+ "--lora_blocks",
+ type=str,
+ default=None,
+ help=(
+ "The transformer blocks to apply LoRA training on. Please specify the block numbers in a comma seperated manner."
+ 'E.g. - "--lora_blocks 12,30" will result in lora training of transformer blocks 12 and 30. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md'
+ ),
+ )
+
parser.add_argument(
"--adam_epsilon",
type=float,
@@ -1213,13 +1241,31 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()
+ if args.lora_layers is not None:
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+ target_modules = [
+ "attn.add_k_proj",
+ "attn.add_q_proj",
+ "attn.add_v_proj",
+ "attn.to_add_out",
+ "attn.to_k",
+ "attn.to_out.0",
+ "attn.to_q",
+ "attn.to_v",
+ ]
+ if args.lora_blocks is not None:
+ target_blocks = [int(block.strip()) for block in args.lora_blocks.split(",")]
+ target_modules = [
+ f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
+ ]
# now we will add new LoRA weights to the attention layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)
@@ -1246,17 +1292,27 @@ def save_model_hook(models, weights, output_dir):
text_encoder_two_lora_layers_to_save = None
for model in models:
- if isinstance(model, type(unwrap_model(transformer))):
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ model = unwrap_model(model)
+ if args.upcast_before_saving:
+ model = model.to(torch.float32)
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
- elif isinstance(model, type(unwrap_model(text_encoder_one))):
- text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
- elif isinstance(model, type(unwrap_model(text_encoder_two))):
- text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+ elif args.train_text_encoder and isinstance(
+ unwrap_model(model), type(unwrap_model(text_encoder_one))
+ ): # or text_encoder_two
+ # both text encoders are of the same class, so we check hidden size to distinguish between the two
+ model = unwrap_model(model)
+ hidden_size = model.config.hidden_size
+ if hidden_size == 768:
+ text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+ elif hidden_size == 1280:
+ text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
- weights.pop()
+ if weights:
+ weights.pop()
StableDiffusion3Pipeline.save_lora_weights(
output_dir,
@@ -1270,17 +1326,31 @@ def load_model_hook(models, input_dir):
text_encoder_one_ = None
text_encoder_two_ = None
- while len(models) > 0:
- model = models.pop()
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+ while len(models) > 0:
+ model = models.pop()
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_ = model
- elif isinstance(model, type(unwrap_model(text_encoder_one))):
- text_encoder_one_ = model
- elif isinstance(model, type(unwrap_model(text_encoder_two))):
- text_encoder_two_ = model
- else:
- raise ValueError(f"unexpected save model: {model.__class__}")
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ transformer_ = unwrap_model(model)
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = unwrap_model(model)
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+ text_encoder_two_ = unwrap_model(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ else:
+ transformer_ = SD3Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer"
+ )
+ transformer_.add_adapter(transformer_lora_config)
+ if args.train_text_encoder:
+ text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder"
+ )
+ text_encoder_two_ = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2"
+ )
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
@@ -1422,7 +1492,6 @@ def load_model_hook(models, input_dir):
optimizer = optimizer_class(
params_to_optimize,
- lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
@@ -1781,7 +1850,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
progress_bar.update(1)
global_step += 1
- if accelerator.is_main_process:
+ if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 016464165c44..f0d993ad9bbc 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -67,6 +67,7 @@
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
convert_unet_state_dict_to_peft,
+ is_peft_version,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
@@ -78,7 +79,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -202,17 +203,17 @@ def log_validation(
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
- pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
autocast_ctx = nullcontext()
else:
- autocast_ctx = torch.autocast(accelerator.device.type)
+ autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
@@ -1183,26 +1184,33 @@ def main(args):
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()
+ def get_lora_config(rank, use_dora, target_modules):
+ base_config = {
+ "r": rank,
+ "lora_alpha": rank,
+ "init_lora_weights": "gaussian",
+ "target_modules": target_modules,
+ }
+ if use_dora:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ base_config["use_dora"] = True
+
+ return LoraConfig(**base_config)
+
# now we will add new LoRA weights to the attention layers
- unet_lora_config = LoraConfig(
- r=args.rank,
- use_dora=args.use_dora,
- lora_alpha=args.rank,
- init_lora_weights="gaussian",
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
- )
+ unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+ unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
unet.add_adapter(unet_lora_config)
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
- text_lora_config = LoraConfig(
- r=args.rank,
- use_dora=args.use_dora,
- lora_alpha=args.rank,
- init_lora_weights="gaussian",
- target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
- )
+ text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+ text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
@@ -1402,7 +1410,6 @@ def load_model_hook(models, input_dir):
optimizer = optimizer_class(
params_to_optimize,
- lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py
index 455ba5a9293d..7a16b64e7d05 100644
--- a/examples/dreambooth/train_dreambooth_sd3.py
+++ b/examples/dreambooth/train_dreambooth_sd3.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,7 +63,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -77,6 +77,15 @@ def save_model_card(
validation_prompt=None,
repo_folder=None,
):
+ if "large" in base_model:
+ model_variant = "SD3.5-Large"
+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
+ variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
+ else:
+ model_variant = "SD3"
+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
+ variant_tags = ["sd3", "sd3-diffusers"]
+
widget_dict = []
if images is not None:
for i, image in enumerate(images):
@@ -86,7 +95,7 @@ def save_model_card(
)
model_description = f"""
-# SD3 DreamBooth - {repo_id}
+# {model_variant} DreamBooth - {repo_id}
@@ -113,7 +122,7 @@ def save_model_card(
## License
-Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`.
+Please adhere to the licensing terms as described `[here]({license_url})`.
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
@@ -128,10 +137,9 @@ def save_model_card(
"text-to-image",
"diffusers-training",
"diffusers",
- "sd3",
- "sd3-diffusers",
"template:sd-lora",
]
+ tags += variant_tags
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
@@ -167,7 +175,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
@@ -894,20 +902,26 @@ def _encode_prompt_with_clip(
tokenizer,
prompt: str,
device=None,
+ text_input_ids=None,
num_images_per_prompt: int = 1,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
- text_inputs = tokenizer(
- prompt,
- padding="max_length",
- max_length=77,
- truncation=True,
- return_tensors="pt",
- )
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
- text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
pooled_prompt_embeds = prompt_embeds[0]
@@ -929,6 +943,7 @@ def encode_prompt(
max_sequence_length,
device=None,
num_images_per_prompt: int = 1,
+ text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -937,13 +952,14 @@ def encode_prompt(
clip_prompt_embeds_list = []
clip_pooled_prompt_embeds_list = []
- for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
+ for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoder,
tokenizer=tokenizer,
prompt=prompt,
device=device if device is not None else text_encoder.device,
num_images_per_prompt=num_images_per_prompt,
+ text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
)
clip_prompt_embeds_list.append(prompt_embeds)
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
@@ -1320,7 +1336,6 @@ def load_model_hook(models, input_dir):
optimizer = optimizer_class(
params_to_optimize,
- lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
diff --git a/examples/flux-control/README.md b/examples/flux-control/README.md
new file mode 100644
index 000000000000..14afa499db0d
--- /dev/null
+++ b/examples/flux-control/README.md
@@ -0,0 +1,204 @@
+# Training Flux Control
+
+This (experimental) example shows how to train Control LoRAs with [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about Flux Control family, refer to the following resources:
+
+* [Docs](https://github.com/black-forest-labs/flux/blob/main/docs/structural-conditioning.md) by Black Forest Labs
+* Diffusers docs ([1](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#canny-control), [2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#depth-control))
+
+To incorporate additional condition latents, we expand the input features of Flux.1-Dev from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `x_embedder` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `FluxControlPipeline`.
+
+> [!NOTE]
+> **Gated model**
+>
+> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+
+```bash
+huggingface-cli login
+```
+
+The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.
+
+```bash
+accelerate launch train_control_lora_flux.py \
+ --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
+ --dataset_name="raulc0399/open_pose_controlnet" \
+ --output_dir="pose-control-lora" \
+ --mixed_precision="bf16" \
+ --train_batch_size=1 \
+ --rank=64 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=5000 \
+ --validation_image="openpose.png" \
+ --validation_prompt="A couple, 4k photo, highly detailed" \
+ --offload \
+ --seed="0" \
+ --push_to_hub
+```
+
+`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).
+
+You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). When it's merged, you should install `diffusers` from the `main`.
+
+The training script exposes additional CLI args that might be useful to experiment with:
+
+* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer.
+* `train_norm_layers`: When set, additionally trains the normalization scales. Takes care of saving and loading.
+* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify "all-linear", all the linear layers will be LoRA-attached.
+
+### Training with DeepSpeed
+
+It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the Zero2 system optimization. To use it, save the following config to an YAML file (feel free to modify as needed):
+
+```yaml
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ gradient_accumulation_steps: 1
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+And then while launching training, pass the config file:
+
+```bash
+accelerate launch --config_file=CONFIG_FILE.yaml ...
+```
+
+### Inference
+
+The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first:
+
+```bash
+pip install controlnet_aux
+```
+
+And then we are ready:
+
+```py
+from controlnet_aux import OpenposeDetector
+from diffusers import FluxControlPipeline
+from diffusers.utils import load_image
+from PIL import Image
+import numpy as np
+import torch
+
+pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
+pipe.load_lora_weights("...") # change this.
+
+open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
+
+# prepare pose condition.
+url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
+image = load_image(url)
+image = open_pose(image, detect_resolution=512, image_resolution=1024)
+image = np.array(image)[:, :, ::-1]
+image = Image.fromarray(np.uint8(image))
+
+prompt = "A couple, 4k photo, highly detailed"
+
+gen_images = pipe(
+ prompt=prompt,
+ control_image=image,
+ num_inference_steps=50,
+ joint_attention_kwargs={"scale": 0.9},
+ guidance_scale=25.,
+).images[0]
+gen_images.save("output.png")
+```
+
+## Full fine-tuning
+
+We provide a non-LoRA version of the training script `train_control_flux.py`. Here is an example command:
+
+```bash
+accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \
+ --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
+ --dataset_name="raulc0399/open_pose_controlnet" \
+ --output_dir="pose-control" \
+ --mixed_precision="bf16" \
+ --train_batch_size=2 \
+ --dataloader_num_workers=4 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --proportion_empty_prompts=0.2 \
+ --learning_rate=5e-5 \
+ --adam_weight_decay=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="cosine" \
+ --lr_warmup_steps=1000 \
+ --checkpointing_steps=1000 \
+ --max_train_steps=10000 \
+ --validation_steps=200 \
+ --validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
+ --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
+ --offload \
+ --seed="0" \
+ --push_to_hub
+```
+
+Change the `validation_image` and `validation_prompt` as needed.
+
+For inference, this time, we will run:
+
+```py
+from controlnet_aux import OpenposeDetector
+from diffusers import FluxControlPipeline, FluxTransformer2DModel
+from diffusers.utils import load_image
+from PIL import Image
+import numpy as np
+import torch
+
+transformer = FluxTransformer2DModel.from_pretrained("...") # change this.
+pipe = FluxControlPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
+).to("cuda")
+
+open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
+
+# prepare pose condition.
+url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
+image = load_image(url)
+image = open_pose(image, detect_resolution=512, image_resolution=1024)
+image = np.array(image)[:, :, ::-1]
+image = Image.fromarray(np.uint8(image))
+
+prompt = "A couple, 4k photo, highly detailed"
+
+gen_images = pipe(
+ prompt=prompt,
+ control_image=image,
+ num_inference_steps=50,
+ guidance_scale=25.,
+).images[0]
+gen_images.save("output.png")
+```
+
+## Things to note
+
+* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
+* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
+* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
\ No newline at end of file
diff --git a/examples/flux-control/requirements.txt b/examples/flux-control/requirements.txt
new file mode 100644
index 000000000000..6c5ec2e03f9a
--- /dev/null
+++ b/examples/flux-control/requirements.txt
@@ -0,0 +1,6 @@
+transformers==4.47.0
+wandb
+torch
+torchvision
+accelerate==1.2.0
+peft>=0.14.0
diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py
new file mode 100644
index 000000000000..d4dbc26a7e5c
--- /dev/null
+++ b/examples/flux-control/train_control_flux.py
@@ -0,0 +1,1249 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import copy
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ free_memory,
+)
+from diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.33.0.dev0")
+
+logger = get_logger(__name__)
+
+NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
+
+
+def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype):
+ pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample()
+ pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor
+ return pixel_latents.to(weight_dtype)
+
+
+def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_final_validation=False):
+ logger.info("Running validation... ")
+
+ if not is_final_validation:
+ flux_transformer = accelerator.unwrap_model(flux_transformer)
+ pipeline = FluxControlPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=flux_transformer,
+ torch_dtype=weight_dtype,
+ )
+ else:
+ transformer = FluxTransformer2DModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
+ pipeline = FluxControlPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=transformer,
+ torch_dtype=weight_dtype,
+ )
+
+ pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+ if is_final_validation or torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype)
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = load_image(validation_image)
+ # maybe need to inference on 1024 to get a good image
+ validation_image = validation_image.resize((args.resolution, args.resolution))
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with autocast_ctx:
+ image = pipeline(
+ prompt=validation_prompt,
+ control_image=validation_image,
+ num_inference_steps=50,
+ guidance_scale=args.guidance_scale,
+ generator=generator,
+ max_sequence_length=512,
+ height=args.resolution,
+ width=args.resolution,
+ ).images[0]
+ image = image.resize((args.resolution, args.resolution))
+ images.append(image)
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ formatted_images = []
+ formatted_images.append(np.asarray(validation_image))
+ for image in images:
+ formatted_images.append(np.asarray(image))
+ formatted_images = np.stack(formatted_images)
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+
+ elif tracker.name == "wandb":
+ formatted_images = []
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({tracker_key: formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ free_memory()
+ return image_logs
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# flux-control-{repo_id}
+
+These are Control weights trained on {base_model} with new type of conditioning.
+{img_str}
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "flux",
+ "flux-diffusers",
+ "text-to-image",
+ "diffusers",
+ "control",
+ "diffusers-training",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a Flux Control training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="flux-control",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the control conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=1,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="flux_train_control",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+ parser.add_argument(
+ "--jsonl_for_train",
+ type=str,
+ default=None,
+ help="Path to the jsonl file containing the training data.",
+ )
+ parser.add_argument(
+ "--only_target_transformer_blocks",
+ action="store_true",
+ help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).",
+ )
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=30.0,
+ help="the guidance scale used for transformer.",
+ )
+
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--offload",
+ action="store_true",
+ help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.jsonl_for_train is None:
+ raise ValueError("Specify either `--dataset_name` or `--jsonl_for_train`")
+
+ if args.dataset_name is not None and args.jsonl_for_train is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--jsonl_for_train`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
+ )
+
+ return args
+
+
+def get_train_dataset(args, accelerator):
+ dataset = None
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ if args.jsonl_for_train is not None:
+ # load from json
+ dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)
+ dataset = dataset.flatten_indices()
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ with accelerator.main_process_first():
+ train_dataset = dataset["train"].shuffle(seed=args.seed)
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(args.max_train_samples))
+ return train_dataset
+
+
+def prepare_train_dataset(dataset, accelerator):
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [
+ (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
+ for image in examples[args.image_column]
+ ]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [
+ (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
+ for image in examples[args.conditioning_image_column]
+ ]
+ conditioning_images = [image_transforms(image) for image in conditioning_images]
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+
+ is_caption_list = isinstance(examples[args.caption_column][0], list)
+ if is_caption_list:
+ examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+ else:
+ examples["captions"] = list(examples[args.caption_column])
+
+ return examples
+
+ with accelerator.main_process_first():
+ dataset = dataset.with_transform(preprocess_train)
+
+ return dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+ captions = [example["captions"] for example in examples]
+ return {"pixel_values": pixel_values, "conditioning_pixel_values": conditioning_pixel_values, "captions": captions}
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_out_dir = Path(args.output_dir, args.logging_dir)
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices.
+ if torch.backends.mps.is_available():
+ logger.info("MPS is enabled. Disabling AMP.")
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ # DEBUG, INFO, WARNING, ERROR, CRITICAL
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load models. We will load the text encoders later in a pipeline to compute
+ # embeddings.
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+ flux_transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ logger.info("All models loaded successfully")
+
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="scheduler",
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ if not args.only_target_transformer_blocks:
+ flux_transformer.requires_grad_(True)
+ vae.requires_grad_(False)
+
+ # cast down and move to the CPU
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # let's not move the VAE to the GPU yet.
+ vae.to(dtype=torch.float32) # keep the VAE in float32.
+
+ # enable image inputs
+ with torch.no_grad():
+ initial_input_channels = flux_transformer.config.in_channels
+ new_linear = torch.nn.Linear(
+ flux_transformer.x_embedder.in_features * 2,
+ flux_transformer.x_embedder.out_features,
+ bias=flux_transformer.x_embedder.bias is not None,
+ dtype=flux_transformer.dtype,
+ device=flux_transformer.device,
+ )
+ new_linear.weight.zero_()
+ new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)
+ if flux_transformer.x_embedder.bias is not None:
+ new_linear.bias.copy_(flux_transformer.x_embedder.bias)
+ flux_transformer.x_embedder = new_linear
+
+ assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
+ flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
+
+ if args.only_target_transformer_blocks:
+ flux_transformer.x_embedder.requires_grad_(True)
+ for name, module in flux_transformer.named_modules():
+ if "transformer_blocks" in name:
+ module.requires_grad_(True)
+ else:
+ module.requirs_grad_(False)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):
+ model = unwrap_model(model)
+ model.save_pretrained(os.path.join(output_dir, "transformer"))
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):
+ transformer_ = model # noqa: F841
+ else:
+ raise ValueError(f"unexpected save model: {unwrap_model(model).__class__}")
+
+ else:
+ transformer_ = FluxTransformer2DModel.from_pretrained(input_dir, subfolder="transformer") # noqa: F841
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ flux_transformer.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimization parameters
+ optimizer = optimizer_class(
+ flux_transformer.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Prepare dataset and dataloader.
+ train_dataset = get_train_dataset(args, accelerator)
+ train_dataset = prepare_train_dataset(train_dataset, accelerator)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+ # Prepare everything with our `accelerator`.
+ flux_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ flux_transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Create a pipeline for text encoding. We will move this pipeline to GPU/CPU as needed.
+ text_encoding_pipeline = FluxControlPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype
+ )
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ logger.info(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+ logger.info("Logging some dataset samples.")
+ formatted_images = []
+ formatted_control_images = []
+ all_prompts = []
+ for i, batch in enumerate(train_dataloader):
+ images = (batch["pixel_values"] + 1) / 2
+ control_images = (batch["conditioning_pixel_values"] + 1) / 2
+ prompts = batch["captions"]
+
+ if len(formatted_images) > 10:
+ break
+
+ for img, control_img, prompt in zip(images, control_images, prompts):
+ formatted_images.append(img)
+ formatted_control_images.append(control_img)
+ all_prompts.append(prompt)
+
+ logged_artifacts = []
+ for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+ logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+ logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+ wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+ wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ flux_transformer.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(flux_transformer):
+ # Convert images to latent space
+ # vae encode
+ pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype)
+ control_latents = encode_images(
+ batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
+ )
+ if args.offload:
+ # offload vae to CPU.
+ vae.cpu()
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ bsz = pixel_latents.shape[0]
+ noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype)
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
+
+ # Add noise according to flow matching.
+ sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)
+ noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise
+ # Concatenate across channels.
+ # Question: Should we concatenate before adding noise?
+ concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)
+
+ # pack the latents.
+ packed_noisy_model_input = FluxControlPipeline._pack_latents(
+ concatenated_noisy_model_input,
+ batch_size=bsz,
+ num_channels_latents=concatenated_noisy_model_input.shape[1],
+ height=concatenated_noisy_model_input.shape[2],
+ width=concatenated_noisy_model_input.shape[3],
+ )
+
+ # latent image ids for RoPE.
+ latent_image_ids = FluxControlPipeline._prepare_latent_image_ids(
+ bsz,
+ concatenated_noisy_model_input.shape[2] // 2,
+ concatenated_noisy_model_input.shape[3] // 2,
+ accelerator.device,
+ weight_dtype,
+ )
+
+ # handle guidance
+ if unwrap_model(flux_transformer).config.guidance_embeds:
+ guidance_vec = torch.full(
+ (bsz,),
+ args.guidance_scale,
+ device=noisy_model_input.device,
+ dtype=weight_dtype,
+ )
+ else:
+ guidance_vec = None
+
+ # text encoding.
+ captions = batch["captions"]
+ text_encoding_pipeline = text_encoding_pipeline.to("cuda")
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+ captions, prompt_2=None
+ )
+ # this could be optimized by not having to do any text encoding and just
+ # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`
+ if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
+ prompt_embeds.zero_()
+ pooled_prompt_embeds.zero_()
+ if args.offload:
+ text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+
+ # Predict.
+ model_pred = flux_transformer(
+ hidden_states=packed_noisy_model_input,
+ timestep=timesteps / 1000,
+ guidance=guidance_vec,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ return_dict=False,
+ )[0]
+ model_pred = FluxControlPipeline._unpack_latents(
+ model_pred,
+ height=noisy_model_input.shape[2] * vae_scale_factor,
+ width=noisy_model_input.shape[3] * vae_scale_factor,
+ vae_scale_factor=vae_scale_factor,
+ )
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow-matching loss
+ target = noise - pixel_latents
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ params_to_clip = flux_transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ flux_transformer=flux_transformer,
+ args=args,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ step=global_step,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ flux_transformer = unwrap_model(flux_transformer)
+ if args.upcast_before_saving:
+ flux_transformer.to(torch.float32)
+ flux_transformer.save_pretrained(args.output_dir)
+
+ del flux_transformer
+ del text_encoding_pipeline
+ del vae
+ free_memory()
+
+ # Run a final round of validation.
+ image_logs = None
+ if args.validation_prompt is not None:
+ image_logs = log_validation(
+ flux_transformer=None,
+ args=args,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ step=global_step,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*", "checkpoint-*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
new file mode 100644
index 000000000000..56c5f2a89a3a
--- /dev/null
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -0,0 +1,1406 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import copy
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ free_memory,
+)
+from diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.33.0.dev0")
+
+logger = get_logger(__name__)
+
+NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
+
+
+def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype):
+ pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample()
+ pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor
+ return pixel_latents.to(weight_dtype)
+
+
+def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_final_validation=False):
+ logger.info("Running validation... ")
+
+ if not is_final_validation:
+ flux_transformer = accelerator.unwrap_model(flux_transformer)
+ pipeline = FluxControlPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=flux_transformer,
+ torch_dtype=weight_dtype,
+ )
+ else:
+ transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", torch_dtype=weight_dtype
+ )
+ initial_channels = transformer.config.in_channels
+ pipeline = FluxControlPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=transformer,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.load_lora_weights(args.output_dir)
+ assert (
+ pipeline.transformer.config.in_channels == initial_channels * 2
+ ), f"{pipeline.transformer.config.in_channels=}"
+
+ pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+ if is_final_validation or torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype)
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = load_image(validation_image)
+ # maybe need to inference on 1024 to get a good image
+ validation_image = validation_image.resize((args.resolution, args.resolution))
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with autocast_ctx:
+ image = pipeline(
+ prompt=validation_prompt,
+ control_image=validation_image,
+ num_inference_steps=50,
+ guidance_scale=args.guidance_scale,
+ generator=generator,
+ max_sequence_length=512,
+ height=args.resolution,
+ width=args.resolution,
+ ).images[0]
+ image = image.resize((args.resolution, args.resolution))
+ images.append(image)
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ formatted_images = []
+ formatted_images.append(np.asarray(validation_image))
+ for image in images:
+ formatted_images.append(np.asarray(image))
+ formatted_images = np.stack(formatted_images)
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+
+ elif tracker.name == "wandb":
+ formatted_images = []
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({tracker_key: formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ free_memory()
+ return image_logs
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# control-lora-{repo_id}
+
+These are Control LoRA weights trained on {base_model} with new type of conditioning.
+{img_str}
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "flux",
+ "flux-diffusers",
+ "text-to-image",
+ "diffusers",
+ "control-lora",
+ "diffusers-training",
+ "lora",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a Control LoRA training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="control-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument("--use_lora_bias", action="store_true", help="If training the bias of lora_B layers.")
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+ 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
+ ),
+ )
+ parser.add_argument(
+ "--gaussian_init_lora",
+ action="store_true",
+ help="If using the Gaussian init strategy. When False, we follow the original LoRA init strategy.",
+ )
+ parser.add_argument("--train_norm_layers", action="store_true", help="Whether to train the norm scales.")
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the control conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=1,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="flux_train_control_lora",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+ parser.add_argument(
+ "--jsonl_for_train",
+ type=str,
+ default=None,
+ help="Path to the jsonl file containing the training data.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=30.0,
+ help="the guidance scale used for transformer.",
+ )
+
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--offload",
+ action="store_true",
+ help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.jsonl_for_train is None:
+ raise ValueError("Specify either `--dataset_name` or `--jsonl_for_train`")
+
+ if args.dataset_name is not None and args.jsonl_for_train is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--jsonl_for_train`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
+ )
+
+ return args
+
+
+def get_train_dataset(args, accelerator):
+ dataset = None
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ if args.jsonl_for_train is not None:
+ # load from json
+ dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)
+ dataset = dataset.flatten_indices()
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ with accelerator.main_process_first():
+ train_dataset = dataset["train"].shuffle(seed=args.seed)
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(args.max_train_samples))
+ return train_dataset
+
+
+def prepare_train_dataset(dataset, accelerator):
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [
+ (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
+ for image in examples[args.image_column]
+ ]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [
+ (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
+ for image in examples[args.conditioning_image_column]
+ ]
+ conditioning_images = [image_transforms(image) for image in conditioning_images]
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+
+ is_caption_list = isinstance(examples[args.caption_column][0], list)
+ if is_caption_list:
+ examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+ else:
+ examples["captions"] = list(examples[args.caption_column])
+
+ return examples
+
+ with accelerator.main_process_first():
+ dataset = dataset.with_transform(preprocess_train)
+
+ return dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+ captions = [example["captions"] for example in examples]
+ return {"pixel_values": pixel_values, "conditioning_pixel_values": conditioning_pixel_values, "captions": captions}
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+ if args.use_lora_bias and args.gaussian_init_lora:
+ raise ValueError("`gaussian` LoRA init scheme isn't supported when `use_lora_bias` is True.")
+
+ logging_out_dir = Path(args.output_dir, args.logging_dir)
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices.
+ if torch.backends.mps.is_available():
+ logger.info("MPS is enabled. Disabling AMP.")
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ # DEBUG, INFO, WARNING, ERROR, CRITICAL
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load models. We will load the text encoders later in a pipeline to compute
+ # embeddings.
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+ flux_transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ logger.info("All models loaded successfully")
+
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="scheduler",
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ vae.requires_grad_(False)
+ flux_transformer.requires_grad_(False)
+
+ # cast down and move to the CPU
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # let's not move the VAE to the GPU yet.
+ vae.to(dtype=torch.float32) # keep the VAE in float32.
+ flux_transformer.to(dtype=weight_dtype, device=accelerator.device)
+
+ # enable image inputs
+ with torch.no_grad():
+ initial_input_channels = flux_transformer.config.in_channels
+ new_linear = torch.nn.Linear(
+ flux_transformer.x_embedder.in_features * 2,
+ flux_transformer.x_embedder.out_features,
+ bias=flux_transformer.x_embedder.bias is not None,
+ dtype=flux_transformer.dtype,
+ device=flux_transformer.device,
+ )
+ new_linear.weight.zero_()
+ new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)
+ if flux_transformer.x_embedder.bias is not None:
+ new_linear.bias.copy_(flux_transformer.x_embedder.bias)
+ flux_transformer.x_embedder = new_linear
+
+ assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
+ flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
+
+ if args.train_norm_layers:
+ for name, param in flux_transformer.named_parameters():
+ if any(k in name for k in NORM_LAYER_PREFIXES):
+ param.requires_grad = True
+
+ if args.lora_layers is not None:
+ if args.lora_layers != "all-linear":
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ # add the input layer to the mix.
+ if "x_embedder" not in target_modules:
+ target_modules.append("x_embedder")
+ elif args.lora_layers == "all-linear":
+ target_modules = set()
+ for name, module in flux_transformer.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ target_modules.add(name)
+ target_modules = list(target_modules)
+ else:
+ target_modules = [
+ "x_embedder",
+ "attn.to_k",
+ "attn.to_q",
+ "attn.to_v",
+ "attn.to_out.0",
+ "attn.add_k_proj",
+ "attn.add_q_proj",
+ "attn.add_v_proj",
+ "attn.to_add_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "ff_context.net.0.proj",
+ "ff_context.net.2",
+ ]
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian" if args.gaussian_init_lora else True,
+ target_modules=target_modules,
+ lora_bias=args.use_lora_bias,
+ )
+ flux_transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):
+ model = unwrap_model(model)
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ if args.train_norm_layers:
+ transformer_norm_layers_to_save = {
+ f"transformer.{name}": param
+ for name, param in model.named_parameters()
+ if any(k in name for k in NORM_LAYER_PREFIXES)
+ }
+ transformer_lora_layers_to_save = {
+ **transformer_lora_layers_to_save,
+ **transformer_norm_layers_to_save,
+ }
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ FluxControlPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(flux_transformer))):
+ transformer_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+ else:
+ transformer_ = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer"
+ ).to(accelerator.device, weight_dtype)
+
+ # Handle input dimension doubling before adding adapter
+ with torch.no_grad():
+ initial_input_channels = transformer_.config.in_channels
+ new_linear = torch.nn.Linear(
+ transformer_.x_embedder.in_features * 2,
+ transformer_.x_embedder.out_features,
+ bias=transformer_.x_embedder.bias is not None,
+ dtype=transformer_.dtype,
+ device=transformer_.device,
+ )
+ new_linear.weight.zero_()
+ new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
+ if transformer_.x_embedder.bias is not None:
+ new_linear.bias.copy_(transformer_.x_embedder.bias)
+ transformer_.x_embedder = new_linear
+ transformer_.register_to_config(in_channels=initial_input_channels * 2)
+
+ transformer_.add_adapter(transformer_lora_config)
+
+ lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
+ transformer_lora_state_dict = {
+ f'{k.replace("transformer.", "")}': v
+ for k, v in lora_state_dict.items()
+ if k.startswith("transformer.") and "lora" in k
+ }
+ incompatible_keys = set_peft_model_state_dict(
+ transformer_, transformer_lora_state_dict, adapter_name="default"
+ )
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+ if args.train_norm_layers:
+ transformer_norm_state_dict = {
+ k: v
+ for k, v in lora_state_dict.items()
+ if k.startswith("transformer.") and any(norm_k in k for norm_k in NORM_LAYER_PREFIXES)
+ }
+ transformer_._transformer_norm_layers = FluxControlPipeline._load_norm_into_transformer(
+ transformer_norm_state_dict,
+ transformer=transformer_,
+ discard_original_layers=False,
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [flux_transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ if args.gradient_checkpointing:
+ flux_transformer.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimization parameters
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, flux_transformer.parameters()))
+ optimizer = optimizer_class(
+ transformer_lora_parameters,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Prepare dataset and dataloader.
+ train_dataset = get_train_dataset(args, accelerator)
+ train_dataset = prepare_train_dataset(train_dataset, accelerator)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+ # Prepare everything with our `accelerator`.
+ flux_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ flux_transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Create a pipeline for text encoding. We will move this pipeline to GPU/CPU as needed.
+ text_encoding_pipeline = FluxControlPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype
+ )
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ logger.info(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+ logger.info("Logging some dataset samples.")
+ formatted_images = []
+ formatted_control_images = []
+ all_prompts = []
+ for i, batch in enumerate(train_dataloader):
+ images = (batch["pixel_values"] + 1) / 2
+ control_images = (batch["conditioning_pixel_values"] + 1) / 2
+ prompts = batch["captions"]
+
+ if len(formatted_images) > 10:
+ break
+
+ for img, control_img, prompt in zip(images, control_images, prompts):
+ formatted_images.append(img)
+ formatted_control_images.append(control_img)
+ all_prompts.append(prompt)
+
+ logged_artifacts = []
+ for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+ logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+ logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+ wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+ wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ flux_transformer.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(flux_transformer):
+ # Convert images to latent space
+ # vae encode
+ pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype)
+ control_latents = encode_images(
+ batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
+ )
+
+ if args.offload:
+ # offload vae to CPU.
+ vae.cpu()
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ bsz = pixel_latents.shape[0]
+ noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype)
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
+
+ # Add noise according to flow matching.
+ sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)
+ noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise
+ # Concatenate across channels.
+ # Question: Should we concatenate before adding noise?
+ concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)
+
+ # pack the latents.
+ packed_noisy_model_input = FluxControlPipeline._pack_latents(
+ concatenated_noisy_model_input,
+ batch_size=bsz,
+ num_channels_latents=concatenated_noisy_model_input.shape[1],
+ height=concatenated_noisy_model_input.shape[2],
+ width=concatenated_noisy_model_input.shape[3],
+ )
+
+ # latent image ids for RoPE.
+ latent_image_ids = FluxControlPipeline._prepare_latent_image_ids(
+ bsz,
+ concatenated_noisy_model_input.shape[2] // 2,
+ concatenated_noisy_model_input.shape[3] // 2,
+ accelerator.device,
+ weight_dtype,
+ )
+
+ # handle guidance
+ if unwrap_model(flux_transformer).config.guidance_embeds:
+ guidance_vec = torch.full(
+ (bsz,),
+ args.guidance_scale,
+ device=noisy_model_input.device,
+ dtype=weight_dtype,
+ )
+ else:
+ guidance_vec = None
+
+ # text encoding.
+ captions = batch["captions"]
+ text_encoding_pipeline = text_encoding_pipeline.to("cuda")
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+ captions, prompt_2=None
+ )
+ # this could be optimized by not having to do any text encoding and just
+ # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`
+ if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
+ prompt_embeds.zero_()
+ pooled_prompt_embeds.zero_()
+ if args.offload:
+ text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+
+ # Predict.
+ model_pred = flux_transformer(
+ hidden_states=packed_noisy_model_input,
+ timestep=timesteps / 1000,
+ guidance=guidance_vec,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ return_dict=False,
+ )[0]
+ model_pred = FluxControlPipeline._unpack_latents(
+ model_pred,
+ height=noisy_model_input.shape[2] * vae_scale_factor,
+ width=noisy_model_input.shape[3] * vae_scale_factor,
+ vae_scale_factor=vae_scale_factor,
+ )
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow-matching loss
+ target = noise - pixel_latents
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ params_to_clip = flux_transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+ if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ flux_transformer=flux_transformer,
+ args=args,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ step=global_step,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ flux_transformer = unwrap_model(flux_transformer)
+ if args.upcast_before_saving:
+ flux_transformer.to(torch.float32)
+ transformer_lora_layers = get_peft_model_state_dict(flux_transformer)
+ if args.train_norm_layers:
+ transformer_norm_layers = {
+ f"transformer.{name}": param
+ for name, param in flux_transformer.named_parameters()
+ if any(k in name for k in NORM_LAYER_PREFIXES)
+ }
+ transformer_lora_layers = {**transformer_lora_layers, **transformer_norm_layers}
+ FluxControlPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ )
+
+ del flux_transformer
+ del text_encoding_pipeline
+ del vae
+ free_memory()
+
+ # Run a final round of validation.
+ image_logs = None
+ if args.validation_prompt is not None:
+ image_logs = log_validation(
+ flux_transformer=None,
+ args=args,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ step=global_step,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*", "*.pt", "*.bin"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index 3cb0c6702599..d1caf281a2c5 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -57,7 +57,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -695,7 +695,7 @@ def preprocess_images(examples):
)
# We need to ensure that the original and the edited images undergo the same
# augmentation transforms.
- images = np.concatenate([original_images, edited_images])
+ images = np.stack([original_images, edited_images])
images = torch.tensor(images)
images = 2 * (images / 255) - 1
return train_transforms(images)
@@ -706,7 +706,7 @@ def preprocess_train(examples):
# Since the original and edited images were concatenated before
# applying the transformations, we need to separate them and reshape
# them accordingly.
- original_images, edited_images = preprocessed_images.chunk(2)
+ original_images, edited_images = preprocessed_images
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index c88be6d16d88..5f01e2f2bb09 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -60,7 +60,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -766,7 +766,7 @@ def preprocess_images(examples):
)
# We need to ensure that the original and the edited images undergo the same
# augmentation transforms.
- images = np.concatenate([original_images, edited_images])
+ images = np.stack([original_images, edited_images])
images = torch.tensor(images)
images = 2 * (images / 255) - 1
return train_transforms(images)
@@ -906,7 +906,7 @@ def preprocess_train(examples):
# Since the original and edited images were concatenated before
# applying the transformations, we need to separate them and reshape
# them accordingly.
- original_images, edited_images = preprocessed_images.chunk(2)
+ original_images, edited_images = preprocessed_images
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index 9caa3694d636..5f5d79fa39f7 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -52,7 +52,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index 23f5d342b396..7bf19915210c 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index 6ed3377db131..af242cead065 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index 448429444448..5a112885b75a 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -51,7 +51,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/model_search/README.md b/examples/model_search/README.md
new file mode 100644
index 000000000000..da7fb3358728
--- /dev/null
+++ b/examples/model_search/README.md
@@ -0,0 +1,155 @@
+# Search models on Civitai and Hugging Face
+
+The [auto_diffusers](https://github.com/suzukimain/auto_diffusers) library provides additional functionalities to Diffusers such as searching for models on Civitai and the Hugging Face Hub.
+Please refer to the original library [here](https://pypi.org/project/auto-diffusers/)
+
+## Installation
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+> [!IMPORTANT]
+> To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment.
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+Set up the pipeline. You can also cd to this folder and run it.
+```bash
+!wget https://raw.githubusercontent.com/suzukimain/auto_diffusers/refs/heads/master/src/auto_diffusers/pipeline_easy.py
+```
+
+## Load from Civitai
+```python
+from pipeline_easy import (
+ EasyPipelineForText2Image,
+ EasyPipelineForImage2Image,
+ EasyPipelineForInpainting,
+)
+
+# Text-to-Image
+pipeline = EasyPipelineForText2Image.from_civitai(
+ "search_word",
+ base_model="SD 1.5",
+).to("cuda")
+
+
+# Image-to-Image
+pipeline = EasyPipelineForImage2Image.from_civitai(
+ "search_word",
+ base_model="SD 1.5",
+).to("cuda")
+
+
+# Inpainting
+pipeline = EasyPipelineForInpainting.from_civitai(
+ "search_word",
+ base_model="SD 1.5",
+).to("cuda")
+```
+
+## Load from Hugging Face
+```python
+from pipeline_easy import (
+ EasyPipelineForText2Image,
+ EasyPipelineForImage2Image,
+ EasyPipelineForInpainting,
+)
+
+# Text-to-Image
+pipeline = EasyPipelineForText2Image.from_huggingface(
+ "search_word",
+ checkpoint_format="diffusers",
+).to("cuda")
+
+
+# Image-to-Image
+pipeline = EasyPipelineForImage2Image.from_huggingface(
+ "search_word",
+ checkpoint_format="diffusers",
+).to("cuda")
+
+
+# Inpainting
+pipeline = EasyPipelineForInpainting.from_huggingface(
+ "search_word",
+ checkpoint_format="diffusers",
+).to("cuda")
+```
+
+
+## Search Civitai and Huggingface
+
+```python
+# Load Lora into the pipeline.
+pipeline.auto_load_lora_weights("Detail Tweaker")
+
+# Load TextualInversion into the pipeline.
+pipeline.auto_load_textual_inversion("EasyNegative", token="EasyNegative")
+```
+
+### Search Civitai
+
+> [!TIP]
+> **If an error occurs, insert the `token` and run again.**
+
+#### `EasyPipeline.from_civitai` parameters
+
+| Name | Type | Default | Description |
+|:---------------:|:----------------------:|:-------------:|:-----------------------------------------------------------------------------------:|
+| search_word | string, Path | ー | The search query string. Can be a keyword, Civitai URL, local directory or file path. |
+| model_type | string | `Checkpoint` | The type of model to search for. (for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`) |
+| base_model | string | None | Trained model tag (for example `SD 1.5`, `SD 3.5`, `SDXL 1.0`) |
+| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. |
+| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. |
+| cache_dir | string, Path | None | Path to the folder where cached files are stored. |
+| resume | bool | False | Whether to resume an incomplete download. |
+| token | string | None | API token for Civitai authentication. |
+
+
+#### `search_civitai` parameters
+
+| Name | Type | Default | Description |
+|:---------------:|:--------------:|:-------------:|:-----------------------------------------------------------------------------------:|
+| search_word | string, Path | ー | The search query string. Can be a keyword, Civitai URL, local directory or file path. |
+| model_type | string | `Checkpoint` | The type of model to search for. (for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`) |
+| base_model | string | None | Trained model tag (for example `SD 1.5`, `SD 3.5`, `SDXL 1.0`) |
+| download | bool | False | Whether to download the model. |
+| force_download | bool | False | Whether to force the download if the model already exists. |
+| cache_dir | string, Path | None | Path to the folder where cached files are stored. |
+| resume | bool | False | Whether to resume an incomplete download. |
+| token | string | None | API token for Civitai authentication. |
+| include_params | bool | False | Whether to include parameters in the returned data. |
+| skip_error | bool | False | Whether to skip errors and return None. |
+
+### Search Huggingface
+
+> [!TIP]
+> **If an error occurs, insert the `token` and run again.**
+
+#### `EasyPipeline.from_huggingface` parameters
+
+| Name | Type | Default | Description |
+|:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:|
+| search_word | string, Path | ー | The search query string. Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`/`). |
+| checkpoint_format | string | `single_file` | The format of the model checkpoint. ● `single_file` to search for `single file checkpoint` ●`diffusers` to search for `multifolder diffusers format checkpoint` |
+| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. |
+| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. |
+| cache_dir | string, Path | None | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. |
+| token | string, bool | None | The token to use as HTTP bearer authorization for remote files. |
+
+
+#### `search_huggingface` parameters
+
+| Name | Type | Default | Description |
+|:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:|
+| search_word | string, Path | ー | The search query string. Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`/`). |
+| checkpoint_format | string | `single_file` | The format of the model checkpoint. ● `single_file` to search for `single file checkpoint` ●`diffusers` to search for `multifolder diffusers format checkpoint` |
+| pipeline_tag | string | None | Tag to filter models by pipeline. |
+| download | bool | False | Whether to download the model. |
+| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. |
+| cache_dir | string, Path | None | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. |
+| token | string, bool | None | The token to use as HTTP bearer authorization for remote files. |
+| include_params | bool | False | Whether to include parameters in the returned data. |
+| skip_error | bool | False | Whether to skip errors and return None. |
diff --git a/examples/model_search/pipeline_easy.py b/examples/model_search/pipeline_easy.py
new file mode 100644
index 000000000000..a8add8311006
--- /dev/null
+++ b/examples/model_search/pipeline_easy.py
@@ -0,0 +1,1911 @@
+# coding=utf-8
+# Copyright 2025 suzukimain
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import types
+from collections import OrderedDict
+from dataclasses import asdict, dataclass, field
+from typing import Dict, List, Optional, Union
+
+import requests
+import torch
+from huggingface_hub import hf_api, hf_hub_download
+from huggingface_hub.file_download import http_get
+from huggingface_hub.utils import validate_hf_hub_args
+
+from diffusers.loaders.single_file_utils import (
+ VALID_URL_PREFIXES,
+ _extract_repo_id_and_weights_name,
+ infer_diffusers_model_type,
+ load_single_file_checkpoint,
+)
+from diffusers.pipelines.animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline
+from diffusers.pipelines.auto_pipeline import (
+ AutoPipelineForImage2Image,
+ AutoPipelineForInpainting,
+ AutoPipelineForText2Image,
+)
+from diffusers.pipelines.controlnet import (
+ StableDiffusionControlNetImg2ImgPipeline,
+ StableDiffusionControlNetInpaintPipeline,
+ StableDiffusionControlNetPipeline,
+ StableDiffusionXLControlNetImg2ImgPipeline,
+ StableDiffusionXLControlNetPipeline,
+)
+from diffusers.pipelines.flux import FluxImg2ImgPipeline, FluxPipeline
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import (
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+ StableDiffusionUpscalePipeline,
+)
+from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
+from diffusers.pipelines.stable_diffusion_xl import (
+ StableDiffusionXLImg2ImgPipeline,
+ StableDiffusionXLInpaintPipeline,
+ StableDiffusionXLPipeline,
+)
+from diffusers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING = OrderedDict(
+ [
+ ("animatediff_rgb", AnimateDiffPipeline),
+ ("animatediff_scribble", AnimateDiffPipeline),
+ ("animatediff_sdxl_beta", AnimateDiffSDXLPipeline),
+ ("animatediff_v1", AnimateDiffPipeline),
+ ("animatediff_v2", AnimateDiffPipeline),
+ ("animatediff_v3", AnimateDiffPipeline),
+ ("autoencoder-dc-f128c512", None),
+ ("autoencoder-dc-f32c32", None),
+ ("autoencoder-dc-f32c32-sana", None),
+ ("autoencoder-dc-f64c128", None),
+ ("controlnet", StableDiffusionControlNetPipeline),
+ ("controlnet_xl", StableDiffusionXLControlNetPipeline),
+ ("controlnet_xl_large", StableDiffusionXLControlNetPipeline),
+ ("controlnet_xl_mid", StableDiffusionXLControlNetPipeline),
+ ("controlnet_xl_small", StableDiffusionXLControlNetPipeline),
+ ("flux-depth", FluxPipeline),
+ ("flux-dev", FluxPipeline),
+ ("flux-fill", FluxPipeline),
+ ("flux-schnell", FluxPipeline),
+ ("hunyuan-video", None),
+ ("inpainting", None),
+ ("inpainting_v2", None),
+ ("ltx-video", None),
+ ("ltx-video-0.9.1", None),
+ ("mochi-1-preview", None),
+ ("playground-v2-5", StableDiffusionXLPipeline),
+ ("sd3", StableDiffusion3Pipeline),
+ ("sd35_large", StableDiffusion3Pipeline),
+ ("sd35_medium", StableDiffusion3Pipeline),
+ ("stable_cascade_stage_b", None),
+ ("stable_cascade_stage_b_lite", None),
+ ("stable_cascade_stage_c", None),
+ ("stable_cascade_stage_c_lite", None),
+ ("upscale", StableDiffusionUpscalePipeline),
+ ("v1", StableDiffusionPipeline),
+ ("v2", StableDiffusionPipeline),
+ ("xl_base", StableDiffusionXLPipeline),
+ ("xl_inpaint", None),
+ ("xl_refiner", StableDiffusionXLPipeline),
+ ]
+)
+
+SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING = OrderedDict(
+ [
+ ("animatediff_rgb", AnimateDiffPipeline),
+ ("animatediff_scribble", AnimateDiffPipeline),
+ ("animatediff_sdxl_beta", AnimateDiffSDXLPipeline),
+ ("animatediff_v1", AnimateDiffPipeline),
+ ("animatediff_v2", AnimateDiffPipeline),
+ ("animatediff_v3", AnimateDiffPipeline),
+ ("autoencoder-dc-f128c512", None),
+ ("autoencoder-dc-f32c32", None),
+ ("autoencoder-dc-f32c32-sana", None),
+ ("autoencoder-dc-f64c128", None),
+ ("controlnet", StableDiffusionControlNetImg2ImgPipeline),
+ ("controlnet_xl", StableDiffusionXLControlNetImg2ImgPipeline),
+ ("controlnet_xl_large", StableDiffusionXLControlNetImg2ImgPipeline),
+ ("controlnet_xl_mid", StableDiffusionXLControlNetImg2ImgPipeline),
+ ("controlnet_xl_small", StableDiffusionXLControlNetImg2ImgPipeline),
+ ("flux-depth", FluxImg2ImgPipeline),
+ ("flux-dev", FluxImg2ImgPipeline),
+ ("flux-fill", FluxImg2ImgPipeline),
+ ("flux-schnell", FluxImg2ImgPipeline),
+ ("hunyuan-video", None),
+ ("inpainting", None),
+ ("inpainting_v2", None),
+ ("ltx-video", None),
+ ("ltx-video-0.9.1", None),
+ ("mochi-1-preview", None),
+ ("playground-v2-5", StableDiffusionXLImg2ImgPipeline),
+ ("sd3", StableDiffusion3Img2ImgPipeline),
+ ("sd35_large", StableDiffusion3Img2ImgPipeline),
+ ("sd35_medium", StableDiffusion3Img2ImgPipeline),
+ ("stable_cascade_stage_b", None),
+ ("stable_cascade_stage_b_lite", None),
+ ("stable_cascade_stage_c", None),
+ ("stable_cascade_stage_c_lite", None),
+ ("upscale", StableDiffusionUpscalePipeline),
+ ("v1", StableDiffusionImg2ImgPipeline),
+ ("v2", StableDiffusionImg2ImgPipeline),
+ ("xl_base", StableDiffusionXLImg2ImgPipeline),
+ ("xl_inpaint", None),
+ ("xl_refiner", StableDiffusionXLImg2ImgPipeline),
+ ]
+)
+
+SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING = OrderedDict(
+ [
+ ("animatediff_rgb", None),
+ ("animatediff_scribble", None),
+ ("animatediff_sdxl_beta", None),
+ ("animatediff_v1", None),
+ ("animatediff_v2", None),
+ ("animatediff_v3", None),
+ ("autoencoder-dc-f128c512", None),
+ ("autoencoder-dc-f32c32", None),
+ ("autoencoder-dc-f32c32-sana", None),
+ ("autoencoder-dc-f64c128", None),
+ ("controlnet", StableDiffusionControlNetInpaintPipeline),
+ ("controlnet_xl", None),
+ ("controlnet_xl_large", None),
+ ("controlnet_xl_mid", None),
+ ("controlnet_xl_small", None),
+ ("flux-depth", None),
+ ("flux-dev", None),
+ ("flux-fill", None),
+ ("flux-schnell", None),
+ ("hunyuan-video", None),
+ ("inpainting", StableDiffusionInpaintPipeline),
+ ("inpainting_v2", StableDiffusionInpaintPipeline),
+ ("ltx-video", None),
+ ("ltx-video-0.9.1", None),
+ ("mochi-1-preview", None),
+ ("playground-v2-5", None),
+ ("sd3", None),
+ ("sd35_large", None),
+ ("sd35_medium", None),
+ ("stable_cascade_stage_b", None),
+ ("stable_cascade_stage_b_lite", None),
+ ("stable_cascade_stage_c", None),
+ ("stable_cascade_stage_c_lite", None),
+ ("upscale", StableDiffusionUpscalePipeline),
+ ("v1", None),
+ ("v2", None),
+ ("xl_base", None),
+ ("xl_inpaint", StableDiffusionXLInpaintPipeline),
+ ("xl_refiner", None),
+ ]
+)
+
+
+CONFIG_FILE_LIST = [
+ "pytorch_model.bin",
+ "pytorch_model.fp16.bin",
+ "diffusion_pytorch_model.bin",
+ "diffusion_pytorch_model.fp16.bin",
+ "diffusion_pytorch_model.safetensors",
+ "diffusion_pytorch_model.fp16.safetensors",
+ "diffusion_pytorch_model.ckpt",
+ "diffusion_pytorch_model.fp16.ckpt",
+ "diffusion_pytorch_model.non_ema.bin",
+ "diffusion_pytorch_model.non_ema.safetensors",
+]
+
+DIFFUSERS_CONFIG_DIR = [
+ "safety_checker",
+ "unet",
+ "vae",
+ "text_encoder",
+ "text_encoder_2",
+]
+
+TOKENIZER_SHAPE_MAP = {
+ 768: [
+ "SD 1.4",
+ "SD 1.5",
+ "SD 1.5 LCM",
+ "SDXL 0.9",
+ "SDXL 1.0",
+ "SDXL 1.0 LCM",
+ "SDXL Distilled",
+ "SDXL Turbo",
+ "SDXL Lightning",
+ "PixArt a",
+ "Playground v2",
+ "Pony",
+ ],
+ 1024: ["SD 2.0", "SD 2.0 768", "SD 2.1", "SD 2.1 768", "SD 2.1 Unclip"],
+}
+
+
+EXTENSION = [".safetensors", ".ckpt", ".bin"]
+
+CACHE_HOME = os.path.expanduser("~/.cache")
+
+
+@dataclass
+class RepoStatus:
+ r"""
+ Data class for storing repository status information.
+
+ Attributes:
+ repo_id (`str`):
+ The name of the repository.
+ repo_hash (`str`):
+ The hash of the repository.
+ version (`str`):
+ The version ID of the repository.
+ """
+
+ repo_id: str = ""
+ repo_hash: str = ""
+ version: str = ""
+
+
+@dataclass
+class ModelStatus:
+ r"""
+ Data class for storing model status information.
+
+ Attributes:
+ search_word (`str`):
+ The search word used to find the model.
+ download_url (`str`):
+ The URL to download the model.
+ file_name (`str`):
+ The name of the model file.
+ local (`bool`):
+ Whether the model exists locally
+ site_url (`str`):
+ The URL of the site where the model is hosted.
+ """
+
+ search_word: str = ""
+ download_url: str = ""
+ file_name: str = ""
+ local: bool = False
+ site_url: str = ""
+
+
+@dataclass
+class ExtraStatus:
+ r"""
+ Data class for storing extra status information.
+
+ Attributes:
+ trained_words (`str`):
+ The words used to trigger the model
+ """
+
+ trained_words: Union[List[str], None] = None
+
+
+@dataclass
+class SearchResult:
+ r"""
+ Data class for storing model data.
+
+ Attributes:
+ model_path (`str`):
+ The path to the model.
+ loading_method (`str`):
+ The type of loading method used for the model ( None or 'from_single_file' or 'from_pretrained')
+ checkpoint_format (`str`):
+ The format of the model checkpoint (`single_file` or `diffusers`).
+ repo_status (`RepoStatus`):
+ The status of the repository.
+ model_status (`ModelStatus`):
+ The status of the model.
+ """
+
+ model_path: str = ""
+ loading_method: Union[str, None] = None
+ checkpoint_format: Union[str, None] = None
+ repo_status: RepoStatus = field(default_factory=RepoStatus)
+ model_status: ModelStatus = field(default_factory=ModelStatus)
+ extra_status: ExtraStatus = field(default_factory=ExtraStatus)
+
+
+@validate_hf_hub_args
+def load_pipeline_from_single_file(pretrained_model_or_path, pipeline_mapping, **kwargs):
+ r"""
+ Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
+ format. The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+ Parameters:
+ pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+ - A link to the `.ckpt` file (for example
+ `"https://huggingface.co//blob/main/.ckpt"`) on the Hub.
+ - A path to a *file* containing all pipeline weights.
+ pipeline_mapping (`dict`):
+ A mapping of model types to their corresponding pipeline classes. This is used to determine
+ which pipeline class to instantiate based on the model type inferred from the checkpoint.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ original_config_file (`str`, *optional*):
+ The path to the original config file that was used to train the model. If not provided, the config file
+ will be inferred from the checkpoint file.
+ config (`str`, *optional*):
+ Can be either:
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
+ component configs in Diffusers format.
+ checkpoint (`dict`, *optional*):
+ The loaded state dictionary of the model.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+ below for more information.
+ """
+
+ # Load the checkpoint from the provided link or path
+ checkpoint = load_single_file_checkpoint(pretrained_model_or_path)
+
+ # Infer the model type from the loaded checkpoint
+ model_type = infer_diffusers_model_type(checkpoint)
+
+ # Get the corresponding pipeline class from the pipeline mapping
+ pipeline_class = pipeline_mapping[model_type]
+
+ # For tasks not supported by this pipeline
+ if pipeline_class is None:
+ raise ValueError(
+ f"{model_type} is not supported in this pipeline."
+ "For `Text2Image`, please use `AutoPipelineForText2Image.from_pretrained`, "
+ "for `Image2Image` , please use `AutoPipelineForImage2Image.from_pretrained`, "
+ "and `inpaint` is only supported in `AutoPipelineForInpainting.from_pretrained`"
+ )
+
+ else:
+ # Instantiate and return the pipeline with the loaded checkpoint and any additional kwargs
+ return pipeline_class.from_single_file(pretrained_model_or_path, **kwargs)
+
+
+def get_keyword_types(keyword):
+ r"""
+ Determine the type and loading method for a given keyword.
+
+ Parameters:
+ keyword (`str`):
+ The input keyword to classify.
+
+ Returns:
+ `dict`: A dictionary containing the model format, loading method,
+ and various types and extra types flags.
+ """
+
+ # Initialize the status dictionary with default values
+ status = {
+ "checkpoint_format": None,
+ "loading_method": None,
+ "type": {
+ "other": False,
+ "hf_url": False,
+ "hf_repo": False,
+ "civitai_url": False,
+ "local": False,
+ },
+ "extra_type": {
+ "url": False,
+ "missing_model_index": None,
+ },
+ }
+
+ # Check if the keyword is an HTTP or HTTPS URL
+ status["extra_type"]["url"] = bool(re.search(r"^(https?)://", keyword))
+
+ # Check if the keyword is a file
+ if os.path.isfile(keyword):
+ status["type"]["local"] = True
+ status["checkpoint_format"] = "single_file"
+ status["loading_method"] = "from_single_file"
+
+ # Check if the keyword is a directory
+ elif os.path.isdir(keyword):
+ status["type"]["local"] = True
+ status["checkpoint_format"] = "diffusers"
+ status["loading_method"] = "from_pretrained"
+ if not os.path.exists(os.path.join(keyword, "model_index.json")):
+ status["extra_type"]["missing_model_index"] = True
+
+ # Check if the keyword is a Civitai URL
+ elif keyword.startswith("https://civitai.com/"):
+ status["type"]["civitai_url"] = True
+ status["checkpoint_format"] = "single_file"
+ status["loading_method"] = None
+
+ # Check if the keyword starts with any valid URL prefixes
+ elif any(keyword.startswith(prefix) for prefix in VALID_URL_PREFIXES):
+ repo_id, weights_name = _extract_repo_id_and_weights_name(keyword)
+ if weights_name:
+ status["type"]["hf_url"] = True
+ status["checkpoint_format"] = "single_file"
+ status["loading_method"] = "from_single_file"
+ else:
+ status["type"]["hf_repo"] = True
+ status["checkpoint_format"] = "diffusers"
+ status["loading_method"] = "from_pretrained"
+
+ # Check if the keyword matches a Hugging Face repository format
+ elif re.match(r"^[^/]+/[^/]+$", keyword):
+ status["type"]["hf_repo"] = True
+ status["checkpoint_format"] = "diffusers"
+ status["loading_method"] = "from_pretrained"
+
+ # If none of the above apply
+ else:
+ status["type"]["other"] = True
+ status["checkpoint_format"] = None
+ status["loading_method"] = None
+
+ return status
+
+
+def file_downloader(
+ url,
+ save_path,
+ **kwargs,
+) -> None:
+ """
+ Downloads a file from a given URL and saves it to the specified path.
+
+ parameters:
+ url (`str`):
+ The URL of the file to download.
+ save_path (`str`):
+ The local path where the file will be saved.
+ resume (`bool`, *optional*, defaults to `False`):
+ Whether to resume an incomplete download.
+ headers (`dict`, *optional*, defaults to `None`):
+ Dictionary of HTTP Headers to send with the request.
+ proxies (`dict`, *optional*, defaults to `None`):
+ Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether to force the download even if the file already exists.
+ displayed_filename (`str`, *optional*):
+ The filename of the file that is being downloaded. Value is used only to display a nice progress bar. If
+ not set, the filename is guessed from the URL or the `Content-Disposition` header.
+
+ returns:
+ None
+ """
+
+ # Get optional parameters from kwargs, with their default values
+ resume = kwargs.pop("resume", False)
+ headers = kwargs.pop("headers", None)
+ proxies = kwargs.pop("proxies", None)
+ force_download = kwargs.pop("force_download", False)
+ displayed_filename = kwargs.pop("displayed_filename", None)
+
+ # Default mode for file writing and initial file size
+ mode = "wb"
+ file_size = 0
+
+ # Create directory
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+ # Check if the file already exists at the save path
+ if os.path.exists(save_path):
+ if not force_download:
+ # If the file exists and force_download is False, skip the download
+ logger.info(f"File already exists: {save_path}, skipping download.")
+ return None
+ elif resume:
+ # If resuming, set mode to append binary and get current file size
+ mode = "ab"
+ file_size = os.path.getsize(save_path)
+
+ # Open the file in the appropriate mode (write or append)
+ with open(save_path, mode) as model_file:
+ # Call the http_get function to perform the file download
+ return http_get(
+ url=url,
+ temp_file=model_file,
+ resume_size=file_size,
+ displayed_filename=displayed_filename,
+ headers=headers,
+ proxies=proxies,
+ **kwargs,
+ )
+
+
+def search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, None]:
+ r"""
+ Downloads a model from Hugging Face.
+
+ Parameters:
+ search_word (`str`):
+ The search query string.
+ revision (`str`, *optional*):
+ The specific version of the model to download.
+ checkpoint_format (`str`, *optional*, defaults to `"single_file"`):
+ The format of the model checkpoint.
+ download (`bool`, *optional*, defaults to `False`):
+ Whether to download the model.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether to force the download if the model already exists.
+ include_params (`bool`, *optional*, defaults to `False`):
+ Whether to include parameters in the returned data.
+ pipeline_tag (`str`, *optional*):
+ Tag to filter models by pipeline.
+ token (`str`, *optional*):
+ API token for Hugging Face authentication.
+ gated (`bool`, *optional*, defaults to `False` ):
+ A boolean to filter models on the Hub that are gated or not.
+ skip_error (`bool`, *optional*, defaults to `False`):
+ Whether to skip errors and return None.
+
+ Returns:
+ `Union[str, SearchResult, None]`: The model path or SearchResult or None.
+ """
+ # Extract additional parameters from kwargs
+ revision = kwargs.pop("revision", None)
+ checkpoint_format = kwargs.pop("checkpoint_format", "single_file")
+ download = kwargs.pop("download", False)
+ force_download = kwargs.pop("force_download", False)
+ include_params = kwargs.pop("include_params", False)
+ pipeline_tag = kwargs.pop("pipeline_tag", None)
+ token = kwargs.pop("token", None)
+ gated = kwargs.pop("gated", False)
+ skip_error = kwargs.pop("skip_error", False)
+
+ file_list = []
+ hf_repo_info = {}
+ hf_security_info = {}
+ model_path = ""
+ repo_id, file_name = "", ""
+ diffusers_model_exists = False
+
+ # Get the type and loading method for the keyword
+ search_word_status = get_keyword_types(search_word)
+
+ if search_word_status["type"]["hf_repo"]:
+ hf_repo_info = hf_api.model_info(repo_id=search_word, securityStatus=True)
+ if download:
+ model_path = DiffusionPipeline.download(
+ search_word,
+ revision=revision,
+ token=token,
+ force_download=force_download,
+ **kwargs,
+ )
+ else:
+ model_path = search_word
+ elif search_word_status["type"]["hf_url"]:
+ repo_id, weights_name = _extract_repo_id_and_weights_name(search_word)
+ if download:
+ model_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=weights_name,
+ force_download=force_download,
+ token=token,
+ )
+ else:
+ model_path = search_word
+ elif search_word_status["type"]["local"]:
+ model_path = search_word
+ elif search_word_status["type"]["civitai_url"]:
+ if skip_error:
+ return None
+ else:
+ raise ValueError("The URL for Civitai is invalid with `for_hf`. Please use `for_civitai` instead.")
+ else:
+ # Get model data from HF API
+ hf_models = hf_api.list_models(
+ search=search_word,
+ direction=-1,
+ limit=100,
+ fetch_config=True,
+ pipeline_tag=pipeline_tag,
+ full=True,
+ gated=gated,
+ token=token,
+ )
+ model_dicts = [asdict(value) for value in list(hf_models)]
+
+ # Loop through models to find a suitable candidate
+ for repo_info in model_dicts:
+ repo_id = repo_info["id"]
+ file_list = []
+ hf_repo_info = hf_api.model_info(repo_id=repo_id, securityStatus=True)
+ # Lists files with security issues.
+ hf_security_info = hf_repo_info.security_repo_status
+ exclusion = [issue["path"] for issue in hf_security_info["filesWithIssues"]]
+
+ # Checks for multi-folder diffusers model or valid files (models with security issues are excluded).
+ if hf_security_info["scansDone"]:
+ for info in repo_info["siblings"]:
+ file_path = info["rfilename"]
+ if "model_index.json" == file_path and checkpoint_format in [
+ "diffusers",
+ "all",
+ ]:
+ diffusers_model_exists = True
+ break
+
+ elif (
+ any(file_path.endswith(ext) for ext in EXTENSION)
+ and not any(config in file_path for config in CONFIG_FILE_LIST)
+ and not any(exc in file_path for exc in exclusion)
+ and os.path.basename(os.path.dirname(file_path)) not in DIFFUSERS_CONFIG_DIR
+ ):
+ file_list.append(file_path)
+
+ # Exit from the loop if a multi-folder diffusers model or valid file is found
+ if diffusers_model_exists or file_list:
+ break
+ else:
+ # Handle case where no models match the criteria
+ if skip_error:
+ return None
+ else:
+ raise ValueError("No models matching your criteria were found on huggingface.")
+
+ if diffusers_model_exists:
+ if download:
+ model_path = DiffusionPipeline.download(
+ repo_id,
+ token=token,
+ **kwargs,
+ )
+ else:
+ model_path = repo_id
+
+ elif file_list:
+ # Sort and find the safest model
+ file_name = next(
+ (model for model in sorted(file_list, reverse=True) if re.search(r"(?i)[-_](safe|sfw)", model)),
+ file_list[0],
+ )
+
+ if download:
+ model_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=file_name,
+ revision=revision,
+ token=token,
+ force_download=force_download,
+ )
+
+ # `pathlib.PosixPath` may be returned
+ if model_path:
+ model_path = str(model_path)
+
+ if file_name:
+ download_url = f"https://huggingface.co/{repo_id}/blob/main/{file_name}"
+ else:
+ download_url = f"https://huggingface.co/{repo_id}"
+
+ output_info = get_keyword_types(model_path)
+
+ if include_params:
+ return SearchResult(
+ model_path=model_path or download_url,
+ loading_method=output_info["loading_method"],
+ checkpoint_format=output_info["checkpoint_format"],
+ repo_status=RepoStatus(repo_id=repo_id, repo_hash=hf_repo_info.sha, version=revision),
+ model_status=ModelStatus(
+ search_word=search_word,
+ site_url=download_url,
+ download_url=download_url,
+ file_name=file_name,
+ local=download,
+ ),
+ extra_status=ExtraStatus(trained_words=None),
+ )
+
+ else:
+ return model_path
+
+
+def search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]:
+ r"""
+ Downloads a model from Civitai.
+
+ Parameters:
+ search_word (`str`):
+ The search query string.
+ model_type (`str`, *optional*, defaults to `Checkpoint`):
+ The type of model to search for.
+ sort (`str`, *optional*):
+ The order in which you wish to sort the results(for example, `Highest Rated`, `Most Downloaded`, `Newest`).
+ base_model (`str`, *optional*):
+ The base model to filter by.
+ download (`bool`, *optional*, defaults to `False`):
+ Whether to download the model.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether to force the download if the model already exists.
+ token (`str`, *optional*):
+ API token for Civitai authentication.
+ include_params (`bool`, *optional*, defaults to `False`):
+ Whether to include parameters in the returned data.
+ cache_dir (`str`, `Path`, *optional*):
+ Path to the folder where cached files are stored.
+ resume (`bool`, *optional*, defaults to `False`):
+ Whether to resume an incomplete download.
+ skip_error (`bool`, *optional*, defaults to `False`):
+ Whether to skip errors and return None.
+
+ Returns:
+ `Union[str, SearchResult, None]`: The model path or ` SearchResult` or None.
+ """
+
+ # Extract additional parameters from kwargs
+ model_type = kwargs.pop("model_type", "Checkpoint")
+ sort = kwargs.pop("sort", None)
+ download = kwargs.pop("download", False)
+ base_model = kwargs.pop("base_model", None)
+ force_download = kwargs.pop("force_download", False)
+ token = kwargs.pop("token", None)
+ include_params = kwargs.pop("include_params", False)
+ resume = kwargs.pop("resume", False)
+ cache_dir = kwargs.pop("cache_dir", None)
+ skip_error = kwargs.pop("skip_error", False)
+
+ # Initialize additional variables with default values
+ model_path = ""
+ repo_name = ""
+ repo_id = ""
+ version_id = ""
+ trainedWords = ""
+ models_list = []
+ selected_repo = {}
+ selected_model = {}
+ selected_version = {}
+ civitai_cache_dir = cache_dir or os.path.join(CACHE_HOME, "Civitai")
+
+ # Set up parameters and headers for the CivitAI API request
+ params = {
+ "query": search_word,
+ "types": model_type,
+ "limit": 20,
+ }
+ if base_model is not None:
+ if not isinstance(base_model, list):
+ base_model = [base_model]
+ params["baseModel"] = base_model
+
+ if sort is not None:
+ params["sort"] = sort
+
+ headers = {}
+ if token:
+ headers["Authorization"] = f"Bearer {token}"
+
+ try:
+ # Make the request to the CivitAI API
+ response = requests.get("https://civitai.com/api/v1/models", params=params, headers=headers)
+ response.raise_for_status()
+ except requests.exceptions.HTTPError as err:
+ raise requests.HTTPError(f"Could not get elements from the URL: {err}")
+ else:
+ try:
+ data = response.json()
+ except AttributeError:
+ if skip_error:
+ return None
+ else:
+ raise ValueError("Invalid JSON response")
+
+ # Sort repositories by download count in descending order
+ sorted_repos = sorted(data["items"], key=lambda x: x["stats"]["downloadCount"], reverse=True)
+
+ for selected_repo in sorted_repos:
+ repo_name = selected_repo["name"]
+ repo_id = selected_repo["id"]
+
+ # Sort versions within the selected repo by download count
+ sorted_versions = sorted(
+ selected_repo["modelVersions"],
+ key=lambda x: x["stats"]["downloadCount"],
+ reverse=True,
+ )
+ for selected_version in sorted_versions:
+ version_id = selected_version["id"]
+ trainedWords = selected_version["trainedWords"]
+ models_list = []
+ # When searching for textual inversion, results other than the values entered for the base model may come up, so check again.
+ if base_model is None or selected_version["baseModel"] in base_model:
+ for model_data in selected_version["files"]:
+ # Check if the file passes security scans and has a valid extension
+ file_name = model_data["name"]
+ if (
+ model_data["pickleScanResult"] == "Success"
+ and model_data["virusScanResult"] == "Success"
+ and any(file_name.endswith(ext) for ext in EXTENSION)
+ and os.path.basename(os.path.dirname(file_name)) not in DIFFUSERS_CONFIG_DIR
+ ):
+ file_status = {
+ "filename": file_name,
+ "download_url": model_data["downloadUrl"],
+ }
+ models_list.append(file_status)
+
+ if models_list:
+ # Sort the models list by filename and find the safest model
+ sorted_models = sorted(models_list, key=lambda x: x["filename"], reverse=True)
+ selected_model = next(
+ (
+ model_data
+ for model_data in sorted_models
+ if bool(re.search(r"(?i)[-_](safe|sfw)", model_data["filename"]))
+ ),
+ sorted_models[0],
+ )
+
+ break
+ else:
+ continue
+ break
+
+ # Exception handling when search candidates are not found
+ if not selected_model:
+ if skip_error:
+ return None
+ else:
+ raise ValueError("No model found. Please try changing the word you are searching for.")
+
+ # Define model file status
+ file_name = selected_model["filename"]
+ download_url = selected_model["download_url"]
+
+ # Handle file download and setting model information
+ if download:
+ # The path where the model is to be saved.
+ model_path = os.path.join(str(civitai_cache_dir), str(repo_id), str(version_id), str(file_name))
+ # Download Model File
+ file_downloader(
+ url=download_url,
+ save_path=model_path,
+ resume=resume,
+ force_download=force_download,
+ displayed_filename=file_name,
+ headers=headers,
+ **kwargs,
+ )
+
+ else:
+ model_path = download_url
+
+ output_info = get_keyword_types(model_path)
+
+ if not include_params:
+ return model_path
+ else:
+ return SearchResult(
+ model_path=model_path,
+ loading_method=output_info["loading_method"],
+ checkpoint_format=output_info["checkpoint_format"],
+ repo_status=RepoStatus(repo_id=repo_name, repo_hash=repo_id, version=version_id),
+ model_status=ModelStatus(
+ search_word=search_word,
+ site_url=f"https://civitai.com/models/{repo_id}?modelVersionId={version_id}",
+ download_url=download_url,
+ file_name=file_name,
+ local=output_info["type"]["local"],
+ ),
+ extra_status=ExtraStatus(trained_words=trainedWords or None),
+ )
+
+
+def add_methods(pipeline):
+ r"""
+ Add methods from `AutoConfig` to the pipeline.
+
+ Parameters:
+ pipeline (`Pipeline`):
+ The pipeline to which the methods will be added.
+ """
+ for attr_name in dir(AutoConfig):
+ attr_value = getattr(AutoConfig, attr_name)
+ if callable(attr_value) and not attr_name.startswith("__"):
+ setattr(pipeline, attr_name, types.MethodType(attr_value, pipeline))
+ return pipeline
+
+
+class AutoConfig:
+ def auto_load_textual_inversion(
+ self,
+ pretrained_model_name_or_path: Union[str, List[str]],
+ token: Optional[Union[str, List[str]]] = None,
+ base_model: Optional[Union[str, List[str]]] = None,
+ tokenizer=None,
+ text_encoder=None,
+ **kwargs,
+ ):
+ r"""
+ Load Textual Inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
+ Automatic1111 formats are supported).
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
+ Can be either one of the following or a list of them:
+
+ - Search keywords for pretrained model (for example `EasyNegative`).
+ - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
+ pretrained model hosted on the Hub.
+ - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
+ inversion weights.
+ - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ token (`str` or `List[str]`, *optional*):
+ Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
+ list, then `token` must also be a list of equal length.
+ text_encoder ([`~transformers.CLIPTextModel`], *optional*):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ If not specified, function will take self.tokenizer.
+ tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
+ A `CLIPTokenizer` to tokenize text. If not specified, function will take self.tokenizer.
+ weight_name (`str`, *optional*):
+ Name of a custom weight file. This should be used when:
+
+ - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
+ name such as `text_inv.bin`.
+ - The saved textual inversion file is in the Automatic1111 format.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+
+ Examples:
+
+ ```py
+ >>> from auto_diffusers import EasyPipelineForText2Image
+
+ >>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
+
+ >>> pipeline.auto_load_textual_inversion("EasyNegative", token="EasyNegative")
+
+ >>> image = pipeline(prompt).images[0]
+ ```
+
+ """
+ # 1. Set tokenizer and text encoder
+ tokenizer = tokenizer or getattr(self, "tokenizer", None)
+ text_encoder = text_encoder or getattr(self, "text_encoder", None)
+
+ # Check if tokenizer and text encoder are provided
+ if tokenizer is None or text_encoder is None:
+ raise ValueError("Tokenizer and text encoder must be provided.")
+
+ # 2. Normalize inputs
+ pretrained_model_name_or_paths = (
+ [pretrained_model_name_or_path]
+ if not isinstance(pretrained_model_name_or_path, list)
+ else pretrained_model_name_or_path
+ )
+
+ # 2.1 Normalize tokens
+ tokens = [token] if not isinstance(token, list) else token
+ if tokens[0] is None:
+ tokens = tokens * len(pretrained_model_name_or_paths)
+
+ for check_token in tokens:
+ # Check if token is already in tokenizer vocabulary
+ if check_token in tokenizer.get_vocab():
+ raise ValueError(
+ f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
+ )
+
+ expected_shape = text_encoder.get_input_embeddings().weight.shape[-1] # Expected shape of tokenizer
+
+ for search_word in pretrained_model_name_or_paths:
+ if isinstance(search_word, str):
+ # Update kwargs to ensure the model is downloaded and parameters are included
+ _status = {
+ "download": True,
+ "include_params": True,
+ "skip_error": False,
+ "model_type": "TextualInversion",
+ }
+ # Get tags for the base model of textual inversion compatible with tokenizer.
+ # If the tokenizer is 768-dimensional, set tags for SD 1.x and SDXL.
+ # If the tokenizer is 1024-dimensional, set tags for SD 2.x.
+ if expected_shape in TOKENIZER_SHAPE_MAP:
+ # Retrieve the appropriate tags from the TOKENIZER_SHAPE_MAP based on the expected shape
+ tags = TOKENIZER_SHAPE_MAP[expected_shape]
+ if base_model is not None:
+ if isinstance(base_model, list):
+ tags.extend(base_model)
+ else:
+ tags.append(base_model)
+ _status["base_model"] = tags
+
+ kwargs.update(_status)
+ # Search for the model on Civitai and get the model status
+ textual_inversion_path = search_civitai(search_word, **kwargs)
+ logger.warning(
+ f"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}"
+ )
+
+ pretrained_model_name_or_paths[
+ pretrained_model_name_or_paths.index(search_word)
+ ] = textual_inversion_path.model_path
+
+ self.load_textual_inversion(
+ pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs
+ )
+
+ def auto_load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ r"""
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
+ loaded.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
+ loaded into `self.unet`.
+
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
+ dict is loaded into `self.text_encoder`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if isinstance(pretrained_model_name_or_path_or_dict, str):
+ # Update kwargs to ensure the model is downloaded and parameters are included
+ _status = {
+ "download": True,
+ "include_params": True,
+ "skip_error": False,
+ "model_type": "LORA",
+ }
+ kwargs.update(_status)
+ # Search for the model on Civitai and get the model status
+ lora_path = search_civitai(pretrained_model_name_or_path_or_dict, **kwargs)
+ logger.warning(f"lora_path: {lora_path.model_status.site_url}")
+ logger.warning(f"trained_words: {lora_path.extra_status.trained_words}")
+ pretrained_model_name_or_path_or_dict = lora_path.model_path
+
+ self.load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name=adapter_name, **kwargs)
+
+
+class EasyPipelineForText2Image(AutoPipelineForText2Image):
+ r"""
+ [`EasyPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The
+ specific underlying pipeline class is automatically selected from either the
+ [`~EasyPipelineForText2Image.from_pretrained`], [`~EasyPipelineForText2Image.from_pipe`], [`~EasyPipelineForText2Image.from_huggingface`] or [`~EasyPipelineForText2Image.from_civitai`] methods.
+
+ This class cannot be instantiated using `__init__()` (throws an error).
+
+ Class attributes:
+
+ - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
+ diffusion pipeline's components.
+
+ """
+
+ config_name = "model_index.json"
+
+ def __init__(self, *args, **kwargs):
+ # EnvironmentError is returned
+ super().__init__()
+
+ @classmethod
+ @validate_hf_hub_args
+ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Parameters:
+ pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A keyword to search for Hugging Face (for example `Stable Diffusion`)
+ - Link to `.ckpt` or `.safetensors` file (for example
+ `"https://huggingface.co//blob/main/.safetensors"`) on the Hub.
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+ saved using
+ [`~DiffusionPipeline.save_pretrained`].
+ checkpoint_format (`str`, *optional*, defaults to `"single_file"`):
+ The format of the model checkpoint.
+ pipeline_tag (`str`, *optional*):
+ Tag to filter models by pipeline.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ custom_revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
+ `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
+ custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn’t need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if device_map contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ gated (`bool`, *optional*, defaults to `False` ):
+ A boolean to filter models on the Hub that are gated or not.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+ below for more information.
+ variant (`str`, *optional*):
+ Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+
+
+
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ `huggingface-cli login`.
+
+
+
+ Examples:
+
+ ```py
+ >>> from auto_diffusers import EasyPipelineForText2Image
+
+ >>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
+ >>> image = pipeline(prompt).images[0]
+ ```
+ """
+ # Update kwargs to ensure the model is downloaded and parameters are included
+ _status = {
+ "download": True,
+ "include_params": True,
+ "skip_error": False,
+ "pipeline_tag": "text-to-image",
+ }
+ kwargs.update(_status)
+
+ # Search for the model on Hugging Face and get the model status
+ hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
+ logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
+ checkpoint_path = hf_checkpoint_status.model_path
+
+ # Check the format of the model checkpoint
+ if hf_checkpoint_status.loading_method == "from_single_file":
+ # Load the pipeline from a single file checkpoint
+ pipeline = load_pipeline_from_single_file(
+ pretrained_model_or_path=checkpoint_path,
+ pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,
+ **kwargs,
+ )
+ else:
+ pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
+ return add_methods(pipeline)
+
+ @classmethod
+ def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Parameters:
+ pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A keyword to search for Hugging Face (for example `Stable Diffusion`)
+ - Link to `.ckpt` or `.safetensors` file (for example
+ `"https://huggingface.co//blob/main/.safetensors"`) on the Hub.
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+ saved using
+ [`~DiffusionPipeline.save_pretrained`].
+ model_type (`str`, *optional*, defaults to `Checkpoint`):
+ The type of model to search for. (for example `Checkpoint`, `TextualInversion`, `LORA`, `Controlnet`)
+ base_model (`str`, *optional*):
+ The base model to filter by.
+ cache_dir (`str`, `Path`, *optional*):
+ Path to the folder where cached files are stored.
+ resume (`bool`, *optional*, defaults to `False`):
+ Whether to resume an incomplete download.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str`, *optional*):
+ The token to use as HTTP bearer authorization for remote files.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn’t need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if device_map contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+ below for more information.
+
+
+
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ `huggingface-cli login`.
+
+
+
+ Examples:
+
+ ```py
+ >>> from auto_diffusers import EasyPipelineForText2Image
+
+ >>> pipeline = EasyPipelineForText2Image.from_huggingface("stable-diffusion-v1-5")
+ >>> image = pipeline(prompt).images[0]
+ ```
+ """
+ # Update kwargs to ensure the model is downloaded and parameters are included
+ _status = {
+ "download": True,
+ "include_params": True,
+ "skip_error": False,
+ "model_type": "Checkpoint",
+ }
+ kwargs.update(_status)
+
+ # Search for the model on Civitai and get the model status
+ checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
+ logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
+ checkpoint_path = checkpoint_status.model_path
+
+ # Load the pipeline from a single file checkpoint
+ pipeline = load_pipeline_from_single_file(
+ pretrained_model_or_path=checkpoint_path,
+ pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,
+ **kwargs,
+ )
+ return add_methods(pipeline)
+
+
+class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
+ r"""
+
+ [`EasyPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The
+ specific underlying pipeline class is automatically selected from either the
+ [`~EasyPipelineForImage2Image.from_pretrained`], [`~EasyPipelineForImage2Image.from_pipe`], [`~EasyPipelineForImage2Image.from_huggingface`] or [`~EasyPipelineForImage2Image.from_civitai`] methods.
+
+ This class cannot be instantiated using `__init__()` (throws an error).
+
+ Class attributes:
+
+ - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
+ diffusion pipeline's components.
+
+ """
+
+ config_name = "model_index.json"
+
+ def __init__(self, *args, **kwargs):
+ # EnvironmentError is returned
+ super().__init__()
+
+ @classmethod
+ @validate_hf_hub_args
+ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Parameters:
+ pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A keyword to search for Hugging Face (for example `Stable Diffusion`)
+ - Link to `.ckpt` or `.safetensors` file (for example
+ `"https://huggingface.co//blob/main/.safetensors"`) on the Hub.
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+ saved using
+ [`~DiffusionPipeline.save_pretrained`].
+ checkpoint_format (`str`, *optional*, defaults to `"single_file"`):
+ The format of the model checkpoint.
+ pipeline_tag (`str`, *optional*):
+ Tag to filter models by pipeline.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ custom_revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
+ `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
+ custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn’t need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if device_map contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ gated (`bool`, *optional*, defaults to `False` ):
+ A boolean to filter models on the Hub that are gated or not.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+ below for more information.
+ variant (`str`, *optional*):
+ Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+
+
+
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ `huggingface-cli login`.
+
+
+
+ Examples:
+
+ ```py
+ >>> from auto_diffusers import EasyPipelineForImage2Image
+
+ >>> pipeline = EasyPipelineForImage2Image.from_huggingface("stable-diffusion-v1-5")
+ >>> image = pipeline(prompt, image).images[0]
+ ```
+ """
+ # Update kwargs to ensure the model is downloaded and parameters are included
+ _parmas = {
+ "download": True,
+ "include_params": True,
+ "skip_error": False,
+ "pipeline_tag": "image-to-image",
+ }
+ kwargs.update(_parmas)
+
+ # Search for the model on Hugging Face and get the model status
+ hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
+ logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
+ checkpoint_path = hf_checkpoint_status.model_path
+
+ # Check the format of the model checkpoint
+ if hf_checkpoint_status.loading_method == "from_single_file":
+ # Load the pipeline from a single file checkpoint
+ pipeline = load_pipeline_from_single_file(
+ pretrained_model_or_path=checkpoint_path,
+ pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,
+ **kwargs,
+ )
+ else:
+ pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
+
+ return add_methods(pipeline)
+
+ @classmethod
+ def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Parameters:
+ pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A keyword to search for Hugging Face (for example `Stable Diffusion`)
+ - Link to `.ckpt` or `.safetensors` file (for example
+ `"https://huggingface.co//blob/main/.safetensors"`) on the Hub.
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+ saved using
+ [`~DiffusionPipeline.save_pretrained`].
+ model_type (`str`, *optional*, defaults to `Checkpoint`):
+ The type of model to search for. (for example `Checkpoint`, `TextualInversion`, `LORA`, `Controlnet`)
+ base_model (`str`, *optional*):
+ The base model to filter by.
+ cache_dir (`str`, `Path`, *optional*):
+ Path to the folder where cached files are stored.
+ resume (`bool`, *optional*, defaults to `False`):
+ Whether to resume an incomplete download.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str`, *optional*):
+ The token to use as HTTP bearer authorization for remote files.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn’t need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if device_map contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+ below for more information.
+
+
+
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ `huggingface-cli login`.
+
+
+
+ Examples:
+
+ ```py
+ >>> from auto_diffusers import EasyPipelineForImage2Image
+
+ >>> pipeline = EasyPipelineForImage2Image.from_huggingface("stable-diffusion-v1-5")
+ >>> image = pipeline(prompt, image).images[0]
+ ```
+ """
+ # Update kwargs to ensure the model is downloaded and parameters are included
+ _status = {
+ "download": True,
+ "include_params": True,
+ "skip_error": False,
+ "model_type": "Checkpoint",
+ }
+ kwargs.update(_status)
+
+ # Search for the model on Civitai and get the model status
+ checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
+ logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
+ checkpoint_path = checkpoint_status.model_path
+
+ # Load the pipeline from a single file checkpoint
+ pipeline = load_pipeline_from_single_file(
+ pretrained_model_or_path=checkpoint_path,
+ pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,
+ **kwargs,
+ )
+ return add_methods(pipeline)
+
+
+class EasyPipelineForInpainting(AutoPipelineForInpainting):
+ r"""
+
+ [`EasyPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The
+ specific underlying pipeline class is automatically selected from either the
+ [`~EasyPipelineForInpainting.from_pretrained`], [`~EasyPipelineForInpainting.from_pipe`], [`~EasyPipelineForInpainting.from_huggingface`] or [`~EasyPipelineForInpainting.from_civitai`] methods.
+
+ This class cannot be instantiated using `__init__()` (throws an error).
+
+ Class attributes:
+
+ - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
+ diffusion pipeline's components.
+
+ """
+
+ config_name = "model_index.json"
+
+ def __init__(self, *args, **kwargs):
+ # EnvironmentError is returned
+ super().__init__()
+
+ @classmethod
+ @validate_hf_hub_args
+ def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Parameters:
+ pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A keyword to search for Hugging Face (for example `Stable Diffusion`)
+ - Link to `.ckpt` or `.safetensors` file (for example
+ `"https://huggingface.co//blob/main/.safetensors"`) on the Hub.
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+ saved using
+ [`~DiffusionPipeline.save_pretrained`].
+ checkpoint_format (`str`, *optional*, defaults to `"single_file"`):
+ The format of the model checkpoint.
+ pipeline_tag (`str`, *optional*):
+ Tag to filter models by pipeline.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ custom_revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
+ `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
+ custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn’t need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if device_map contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ gated (`bool`, *optional*, defaults to `False` ):
+ A boolean to filter models on the Hub that are gated or not.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+ below for more information.
+ variant (`str`, *optional*):
+ Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+
+
+
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ `huggingface-cli login`.
+
+
+
+ Examples:
+
+ ```py
+ >>> from auto_diffusers import EasyPipelineForInpainting
+
+ >>> pipeline = EasyPipelineForInpainting.from_huggingface("stable-diffusion-2-inpainting")
+ >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
+ ```
+ """
+ # Update kwargs to ensure the model is downloaded and parameters are included
+ _status = {
+ "download": True,
+ "include_params": True,
+ "skip_error": False,
+ "pipeline_tag": "image-to-image",
+ }
+ kwargs.update(_status)
+
+ # Search for the model on Hugging Face and get the model status
+ hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)
+ logger.warning(f"checkpoint_path: {hf_checkpoint_status.model_status.download_url}")
+ checkpoint_path = hf_checkpoint_status.model_path
+
+ # Check the format of the model checkpoint
+ if hf_checkpoint_status.loading_method == "from_single_file":
+ # Load the pipeline from a single file checkpoint
+ pipeline = load_pipeline_from_single_file(
+ pretrained_model_or_path=checkpoint_path,
+ pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,
+ **kwargs,
+ )
+ else:
+ pipeline = cls.from_pretrained(checkpoint_path, **kwargs)
+ return add_methods(pipeline)
+
+ @classmethod
+ def from_civitai(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Parameters:
+ pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A keyword to search for Hugging Face (for example `Stable Diffusion`)
+ - Link to `.ckpt` or `.safetensors` file (for example
+ `"https://huggingface.co//blob/main/.safetensors"`) on the Hub.
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
+ saved using
+ [`~DiffusionPipeline.save_pretrained`].
+ model_type (`str`, *optional*, defaults to `Checkpoint`):
+ The type of model to search for. (for example `Checkpoint`, `TextualInversion`, `LORA`, `Controlnet`)
+ base_model (`str`, *optional*):
+ The base model to filter by.
+ cache_dir (`str`, `Path`, *optional*):
+ Path to the folder where cached files are stored.
+ resume (`bool`, *optional*, defaults to `False`):
+ Whether to resume an incomplete download.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str`, *optional*):
+ The token to use as HTTP bearer authorization for remote files.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn’t need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if device_map contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+ below for more information.
+
+
+
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ `huggingface-cli login`.
+
+
+
+ Examples:
+
+ ```py
+ >>> from auto_diffusers import EasyPipelineForInpainting
+
+ >>> pipeline = EasyPipelineForInpainting.from_huggingface("stable-diffusion-2-inpainting")
+ >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
+ ```
+ """
+ # Update kwargs to ensure the model is downloaded and parameters are included
+ _status = {
+ "download": True,
+ "include_params": True,
+ "skip_error": False,
+ "model_type": "Checkpoint",
+ }
+ kwargs.update(_status)
+
+ # Search for the model on Civitai and get the model status
+ checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)
+ logger.warning(f"checkpoint_path: {checkpoint_status.model_status.site_url}")
+ checkpoint_path = checkpoint_status.model_path
+
+ # Load the pipeline from a single file checkpoint
+ pipeline = load_pipeline_from_single_file(
+ pretrained_model_or_path=checkpoint_path,
+ pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,
+ **kwargs,
+ )
+ return add_methods(pipeline)
diff --git a/examples/model_search/requirements.txt b/examples/model_search/requirements.txt
new file mode 100644
index 000000000000..db7bc19a3a2b
--- /dev/null
+++ b/examples/model_search/requirements.txt
@@ -0,0 +1 @@
+huggingface-hub>=0.26.2
diff --git a/examples/reinforcement_learning/README.md b/examples/reinforcement_learning/README.md
index 3c3ada2031cf..30d3b5bb1dd8 100644
--- a/examples/reinforcement_learning/README.md
+++ b/examples/reinforcement_learning/README.md
@@ -1,4 +1,13 @@
-# Overview
+
+## Diffusion-based Policy Learning for RL
+
+`diffusion_policy` implements [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/), a diffusion model that predicts robot action sequences in reinforcement learning tasks.
+
+This example implements a robot control model for pushing a T-shaped block into a target area. The model takes in current state observations as input, and outputs a trajectory of subsequent steps to follow.
+
+To execute the script, run `diffusion_policy.py`
+
+## Diffuser Locomotion
These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
There are two ways to use the script, `run_diffuser_locomotion.py`.
diff --git a/examples/reinforcement_learning/diffusion_policy.py b/examples/reinforcement_learning/diffusion_policy.py
new file mode 100644
index 000000000000..3ef4c1dabc2e
--- /dev/null
+++ b/examples/reinforcement_learning/diffusion_policy.py
@@ -0,0 +1,201 @@
+import numpy as np
+import numpy.core.multiarray as multiarray
+import torch
+import torch.nn as nn
+from huggingface_hub import hf_hub_download
+from torch.serialization import add_safe_globals
+
+from diffusers import DDPMScheduler, UNet1DModel
+
+
+add_safe_globals(
+ [
+ multiarray._reconstruct,
+ np.ndarray,
+ np.dtype,
+ np.dtype(np.float32).type,
+ np.dtype(np.float64).type,
+ np.dtype(np.int32).type,
+ np.dtype(np.int64).type,
+ type(np.dtype(np.float32)),
+ type(np.dtype(np.float64)),
+ type(np.dtype(np.int32)),
+ type(np.dtype(np.int64)),
+ ]
+)
+
+"""
+An example of using HuggingFace's diffusers library for diffusion policy,
+generating smooth movement trajectories.
+
+This implements a robot control model for pushing a T-shaped block into a target area.
+The model takes in the robot arm position, block position, and block angle,
+then outputs a sequence of 16 (x,y) positions for the robot arm to follow.
+"""
+
+
+class ObservationEncoder(nn.Module):
+ """
+ Converts raw robot observations (positions/angles) into a more compact representation
+
+ state_dim (int): Dimension of the input state vector (default: 5)
+ [robot_x, robot_y, block_x, block_y, block_angle]
+
+ - Input shape: (batch_size, state_dim)
+ - Output shape: (batch_size, 256)
+ """
+
+ def __init__(self, state_dim):
+ super().__init__()
+ self.net = nn.Sequential(nn.Linear(state_dim, 512), nn.ReLU(), nn.Linear(512, 256))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class ObservationProjection(nn.Module):
+ """
+ Takes the encoded observation and transforms it into 32 values that represent the current robot/block situation.
+ These values are used as additional contextual information during the diffusion model's trajectory generation.
+
+ - Input: 256-dim vector (padded to 512)
+ Shape: (batch_size, 256)
+ - Output: 32 contextual information values for the diffusion model
+ Shape: (batch_size, 32)
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(32, 512))
+ self.bias = nn.Parameter(torch.zeros(32))
+
+ def forward(self, x): # pad 256-dim input to 512-dim with zeros
+ if x.size(-1) == 256:
+ x = torch.cat([x, torch.zeros(*x.shape[:-1], 256, device=x.device)], dim=-1)
+ return nn.functional.linear(x, self.weight, self.bias)
+
+
+class DiffusionPolicy:
+ """
+ Implements diffusion policy for generating robot arm trajectories.
+ Uses diffusion to generate sequences of positions for a robot arm, conditioned on
+ the current state of the robot and the block it needs to push.
+
+ The model expects observations in pixel coordinates (0-512 range) and block angle in radians.
+ It generates trajectories as sequences of (x,y) coordinates also in the 0-512 range.
+ """
+
+ def __init__(self, state_dim=5, device="cpu"):
+ self.device = device
+
+ # define valid ranges for inputs/outputs
+ self.stats = {
+ "obs": {"min": torch.zeros(5), "max": torch.tensor([512, 512, 512, 512, 2 * np.pi])},
+ "action": {"min": torch.zeros(2), "max": torch.full((2,), 512)},
+ }
+
+ self.obs_encoder = ObservationEncoder(state_dim).to(device)
+ self.obs_projection = ObservationProjection().to(device)
+
+ # UNet model that performs the denoising process
+ # takes in concatenated action (2 channels) and context (32 channels) = 34 channels
+ # outputs predicted action (2 channels for x,y coordinates)
+ self.model = UNet1DModel(
+ sample_size=16, # length of trajectory sequence
+ in_channels=34,
+ out_channels=2,
+ layers_per_block=2, # number of layers per each UNet block
+ block_out_channels=(128,), # number of output neurons per layer in each block
+ down_block_types=("DownBlock1D",), # reduce the resolution of data
+ up_block_types=("UpBlock1D",), # increase the resolution of data
+ ).to(device)
+
+ # noise scheduler that controls the denoising process
+ self.noise_scheduler = DDPMScheduler(
+ num_train_timesteps=100, # number of denoising steps
+ beta_schedule="squaredcos_cap_v2", # type of noise schedule
+ )
+
+ # load pre-trained weights from HuggingFace
+ checkpoint = torch.load(
+ hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), weights_only=True, map_location=device
+ )
+ self.model.load_state_dict(checkpoint["model_state_dict"])
+ self.obs_encoder.load_state_dict(checkpoint["encoder_state_dict"])
+ self.obs_projection.load_state_dict(checkpoint["projection_state_dict"])
+
+ # scales data to [-1, 1] range for neural network processing
+ def normalize_data(self, data, stats):
+ return ((data - stats["min"]) / (stats["max"] - stats["min"])) * 2 - 1
+
+ # converts normalized data back to original range
+ def unnormalize_data(self, ndata, stats):
+ return ((ndata + 1) / 2) * (stats["max"] - stats["min"]) + stats["min"]
+
+ @torch.no_grad()
+ def predict(self, observation):
+ """
+ Generates a trajectory of robot arm positions given the current state.
+
+ Args:
+ observation (torch.Tensor): Current state [robot_x, robot_y, block_x, block_y, block_angle]
+ Shape: (batch_size, 5)
+
+ Returns:
+ torch.Tensor: Sequence of (x,y) positions for the robot arm to follow
+ Shape: (batch_size, 16, 2) where:
+ - 16 is the number of steps in the trajectory
+ - 2 is the (x,y) coordinates in pixel space (0-512)
+
+ The function first encodes the observation, then uses it to condition a diffusion
+ process that gradually denoises random trajectories into smooth, purposeful movements.
+ """
+ observation = observation.to(self.device)
+ normalized_obs = self.normalize_data(observation, self.stats["obs"])
+
+ # encode the observation into context values for the diffusion model
+ cond = self.obs_projection(self.obs_encoder(normalized_obs))
+ # keeps first & second dimension sizes unchanged, and multiplies last dimension by 16
+ cond = cond.view(normalized_obs.shape[0], -1, 1).expand(-1, -1, 16)
+
+ # initialize action with noise - random noise that will be refined into a trajectory
+ action = torch.randn((observation.shape[0], 2, 16), device=self.device)
+
+ # denoise
+ # at each step `t`, the current noisy trajectory (`action`) & conditioning info (context) are
+ # fed into the model to predict a denoised trajectory, then uses self.noise_scheduler.step to
+ # apply this prediction & slightly reduce the noise in `action` more
+
+ self.noise_scheduler.set_timesteps(100)
+ for t in self.noise_scheduler.timesteps:
+ model_output = self.model(torch.cat([action, cond], dim=1), t)
+ action = self.noise_scheduler.step(model_output.sample, t, action).prev_sample
+
+ action = action.transpose(1, 2) # reshape to [batch, 16, 2]
+ action = self.unnormalize_data(action, self.stats["action"]) # scale back to coordinates
+ return action
+
+
+if __name__ == "__main__":
+ policy = DiffusionPolicy()
+
+ # sample of a single observation
+ # robot arm starts in center, block is slightly left and up, rotated 90 degrees
+ obs = torch.tensor(
+ [
+ [
+ 256.0, # robot arm x position (middle of screen)
+ 256.0, # robot arm y position (middle of screen)
+ 200.0, # block x position
+ 300.0, # block y position
+ np.pi / 2, # block angle (90 degrees)
+ ]
+ ]
+ )
+
+ action = policy.predict(obs)
+
+ print("Action shape:", action.shape) # should be [1, 16, 2] - one trajectory of 16 x,y positions
+ print("\nPredicted trajectory:")
+ for i, (x, y) in enumerate(action[0]):
+ print(f"Step {i:2d}: x={x:6.1f}, y={y:6.1f}")
diff --git a/examples/research_projects/anytext/README.md b/examples/research_projects/anytext/README.md
new file mode 100644
index 000000000000..3a67efd8b2f4
--- /dev/null
+++ b/examples/research_projects/anytext/README.md
@@ -0,0 +1,40 @@
+# AnyTextPipeline
+
+Project page: https://aigcdesigngroup.github.io/homepage_anytext
+
+"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy."
+
+> **Note:** Each text line that needs to be generated should be enclosed in double quotes.
+
+For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054).
+
+[](https://colab.research.google.com/gist/tolgacangoz/b87ec9d2f265b448dd947c9d4a0da389/anytext.ipynb)
+
+```py
+# This example requires the `anytext_controlnet.py` file:
+# !git clone --depth 1 https://github.com/huggingface/diffusers.git
+# %cd diffusers/examples/research_projects/anytext
+# Let's choose a font file shared by an HF staff:
+# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
+
+import torch
+from diffusers import DiffusionPipeline
+from anytext_controlnet import AnyTextControlNetModel
+from diffusers.utils import load_image
+
+
+anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
+ variant="fp16",)
+pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
+ controlnet=anytext_controlnet, torch_dtype=torch.float16,
+ trust_remote_code=False, # One needs to give permission to run this pipeline's code
+ ).to("cuda")
+
+# generate image
+prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
+draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
+# There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
+image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
+ ).images[0]
+image
+```
diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py
new file mode 100644
index 000000000000..5c30b24efe88
--- /dev/null
+++ b/examples/research_projects/anytext/anytext.py
@@ -0,0 +1,2366 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright (c) Alibaba, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054).
+# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie
+# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license
+#
+# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
+
+
+import inspect
+import math
+import os
+import re
+import sys
+import unicodedata
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from huggingface_hub import hf_hub_download
+from ocr_recog.RecModel import RecModel
+from PIL import Image, ImageDraw, ImageFont
+from safetensors.torch import load_file
+from skimage.transform._geometric import _umeyama as get_sym_mat
+from torch import nn
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.constants import HF_MODULES_CACHE
+from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+
+
+class Checker:
+ def __init__(self):
+ pass
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+ # like the all of the other languages.
+ if (
+ (cp >= 0x4E00 and cp <= 0x9FFF)
+ or (cp >= 0x3400 and cp <= 0x4DBF)
+ or (cp >= 0x20000 and cp <= 0x2A6DF)
+ or (cp >= 0x2A700 and cp <= 0x2B73F)
+ or (cp >= 0x2B740 and cp <= 0x2B81F)
+ or (cp >= 0x2B820 and cp <= 0x2CEAF)
+ or (cp >= 0xF900 and cp <= 0xFAFF)
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
+ ):
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xFFFD or self._is_control(char):
+ continue
+ if self._is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_control(self, char):
+ """Checks whether `chars` is a control character."""
+ # These are technically control characters but we count them as whitespace
+ # characters.
+ if char == "\t" or char == "\n" or char == "\r":
+ return False
+ cat = unicodedata.category(char)
+ if cat in ("Cc", "Cf"):
+ return True
+ return False
+
+ def _is_whitespace(self, char):
+ """Checks whether `chars` is a whitespace character."""
+ # \t, \n, and \r are technically control characters but we treat them
+ # as whitespace since they are generally considered as such.
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
+ return True
+ cat = unicodedata.category(char)
+ if cat == "Zs":
+ return True
+ return False
+
+
+checker = Checker()
+
+
+PLACE_HOLDER = "*"
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # This example requires the `anytext_controlnet.py` file:
+ >>> # !git clone --depth 1 https://github.com/huggingface/diffusers.git
+ >>> # %cd diffusers/examples/research_projects/anytext
+ >>> # Let's choose a font file shared by an HF staff:
+ >>> # !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
+
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+ >>> from anytext_controlnet import AnyTextControlNetModel
+ >>> from diffusers.utils import load_image
+
+ >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
+ ... variant="fp16",)
+ >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
+ ... controlnet=anytext_controlnet, torch_dtype=torch.float16,
+ ... trust_remote_code=False, # One needs to give permission to run this pipeline's code
+ ... ).to("cuda")
+
+
+ >>> # generate image
+ >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
+ >>> draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
+ >>> # There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
+ >>> image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
+ ... ).images[0]
+ >>> image
+ ```
+"""
+
+
+def get_clip_token_for_string(tokenizer, string):
+ batch_encoding = tokenizer(
+ string,
+ truncation=True,
+ max_length=77,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"]
+ assert (
+ torch.count_nonzero(tokens - 49407) == 2
+ ), f"String '{string}' maps to more than a single token. Please use another string"
+ return tokens[0, 1]
+
+
+def get_recog_emb(encoder, img_list):
+ _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list]
+ encoder.predictor.eval()
+ _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
+ return preds_neck
+
+
+class EmbeddingManager(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ embedder,
+ placeholder_string="*",
+ use_fp16=False,
+ token_dim=768,
+ get_recog_emb=None,
+ ):
+ super().__init__()
+ get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
+
+ self.proj = nn.Linear(40 * 64, token_dim)
+ proj_dir = hf_hub_download(
+ repo_id="tolgacangoz/anytext",
+ filename="text_embedding_module/proj.safetensors",
+ cache_dir=HF_MODULES_CACHE,
+ )
+ self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
+ if use_fp16:
+ self.proj = self.proj.to(dtype=torch.float16)
+
+ self.placeholder_token = get_token_for_string(placeholder_string)
+
+ @torch.no_grad()
+ def encode_text(self, text_info):
+ if self.config.get_recog_emb is None:
+ self.config.get_recog_emb = partial(get_recog_emb, self.recog)
+
+ gline_list = []
+ for i in range(len(text_info["n_lines"])): # sample index in a batch
+ n_lines = text_info["n_lines"][i]
+ for j in range(n_lines): # line
+ gline_list += [text_info["gly_line"][j][i : i + 1]]
+
+ if len(gline_list) > 0:
+ recog_emb = self.config.get_recog_emb(gline_list)
+ enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype))
+
+ self.text_embs_all = []
+ n_idx = 0
+ for i in range(len(text_info["n_lines"])): # sample index in a batch
+ n_lines = text_info["n_lines"][i]
+ text_embs = []
+ for j in range(n_lines): # line
+ text_embs += [enc_glyph[n_idx : n_idx + 1]]
+ n_idx += 1
+ self.text_embs_all += [text_embs]
+
+ @torch.no_grad()
+ def forward(
+ self,
+ tokenized_text,
+ embedded_text,
+ ):
+ b, device = tokenized_text.shape[0], tokenized_text.device
+ for i in range(b):
+ idx = tokenized_text[i] == self.placeholder_token.to(device)
+ if sum(idx) > 0:
+ if i >= len(self.text_embs_all):
+ logger.warning("truncation for log images...")
+ break
+ text_emb = torch.cat(self.text_embs_all[i], dim=0)
+ if sum(idx) != len(text_emb):
+ logger.warning("truncation for long caption...")
+ text_emb = text_emb.to(embedded_text.device)
+ embedded_text[i][idx] = text_emb[: sum(idx)]
+ return embedded_text
+
+ def embedding_parameters(self):
+ return self.parameters()
+
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+
+def min_bounding_rect(img):
+ ret, thresh = cv2.threshold(img, 127, 255, 0)
+ contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ if len(contours) == 0:
+ print("Bad contours, using fake bbox...")
+ return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
+ max_contour = max(contours, key=cv2.contourArea)
+ rect = cv2.minAreaRect(max_contour)
+ box = cv2.boxPoints(rect)
+ box = np.int0(box)
+ # sort
+ x_sorted = sorted(box, key=lambda x: x[0])
+ left = x_sorted[:2]
+ right = x_sorted[2:]
+ left = sorted(left, key=lambda x: x[1])
+ (tl, bl) = left
+ right = sorted(right, key=lambda x: x[1])
+ (tr, br) = right
+ if tl[1] > bl[1]:
+ (tl, bl) = (bl, tl)
+ if tr[1] > br[1]:
+ (tr, br) = (br, tr)
+ return np.array([tl, tr, br, bl])
+
+
+def adjust_image(box, img):
+ pts1 = np.float32([box[0], box[1], box[2], box[3]])
+ width = max(np.linalg.norm(pts1[0] - pts1[1]), np.linalg.norm(pts1[2] - pts1[3]))
+ height = max(np.linalg.norm(pts1[0] - pts1[3]), np.linalg.norm(pts1[1] - pts1[2]))
+ pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
+ # get transform matrix
+ M = get_sym_mat(pts1, pts2, estimate_scale=True)
+ C, H, W = img.shape
+ T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]])
+ theta = np.linalg.inv(T @ M @ np.linalg.inv(T))
+ theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device)
+ grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True)
+ result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True)
+ result = torch.clamp(result.squeeze(0), 0, 255)
+ # crop
+ result = result[:, : int(height), : int(width)]
+ return result
+
+
+def crop_image(src_img, mask):
+ box = min_bounding_rect(mask)
+ result = adjust_image(box, src_img)
+ if len(result.shape) == 2:
+ result = torch.stack([result] * 3, axis=-1)
+ return result
+
+
+def create_predictor(model_lang="ch", device="cpu", use_fp16=False):
+ model_dir = hf_hub_download(
+ repo_id="tolgacangoz/anytext",
+ filename="text_embedding_module/OCR/ppv3_rec.pth",
+ cache_dir=HF_MODULES_CACHE,
+ )
+ if not os.path.exists(model_dir):
+ raise ValueError("not find model file path {}".format(model_dir))
+
+ if model_lang == "ch":
+ n_class = 6625
+ elif model_lang == "en":
+ n_class = 97
+ else:
+ raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
+ rec_config = {
+ "in_channels": 3,
+ "backbone": {"type": "MobileNetV1Enhance", "scale": 0.5, "last_conv_stride": [1, 2], "last_pool_type": "avg"},
+ "neck": {
+ "type": "SequenceEncoder",
+ "encoder_type": "svtr",
+ "dims": 64,
+ "depth": 2,
+ "hidden_dims": 120,
+ "use_guide": True,
+ },
+ "head": {"type": "CTCHead", "fc_decay": 0.00001, "out_channels": n_class, "return_feats": True},
+ }
+
+ rec_model = RecModel(rec_config)
+ state_dict = torch.load(model_dir, map_location=device)
+ rec_model.load_state_dict(state_dict)
+ return rec_model
+
+
+def _check_image_file(path):
+ img_end = ("tiff", "tif", "bmp", "rgb", "jpg", "png", "jpeg")
+ return path.lower().endswith(tuple(img_end))
+
+
+def get_image_file_list(img_file):
+ imgs_lists = []
+ if img_file is None or not os.path.exists(img_file):
+ raise Exception("not found any img file in {}".format(img_file))
+ if os.path.isfile(img_file) and _check_image_file(img_file):
+ imgs_lists.append(img_file)
+ elif os.path.isdir(img_file):
+ for single_file in os.listdir(img_file):
+ file_path = os.path.join(img_file, single_file)
+ if os.path.isfile(file_path) and _check_image_file(file_path):
+ imgs_lists.append(file_path)
+ if len(imgs_lists) == 0:
+ raise Exception("not found any img file in {}".format(img_file))
+ imgs_lists = sorted(imgs_lists)
+ return imgs_lists
+
+
+class TextRecognizer(object):
+ def __init__(self, args, predictor):
+ self.rec_image_shape = [int(v) for v in args["rec_image_shape"].split(",")]
+ self.rec_batch_num = args["rec_batch_num"]
+ self.predictor = predictor
+ self.chars = self.get_char_dict(args["rec_char_dict_path"])
+ self.char2id = {x: i for i, x in enumerate(self.chars)}
+ self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
+ self.use_fp16 = args["use_fp16"]
+
+ # img: CHW
+ def resize_norm_img(self, img, max_wh_ratio):
+ imgC, imgH, imgW = self.rec_image_shape
+ assert imgC == img.shape[0]
+ imgW = int((imgH * max_wh_ratio))
+
+ h, w = img.shape[1:]
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = torch.nn.functional.interpolate(
+ img.unsqueeze(0),
+ size=(imgH, resized_w),
+ mode="bilinear",
+ align_corners=True,
+ )
+ resized_image /= 255.0
+ resized_image -= 0.5
+ resized_image /= 0.5
+ padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
+ padding_im[:, :, 0:resized_w] = resized_image[0]
+ return padding_im
+
+ # img_list: list of tensors with shape chw 0-255
+ def pred_imglist(self, img_list, show_debug=False):
+ img_num = len(img_list)
+ assert img_num > 0
+ # Calculate the aspect ratio of all text bars
+ width_list = []
+ for img in img_list:
+ width_list.append(img.shape[2] / float(img.shape[1]))
+ # Sorting can speed up the recognition process
+ indices = torch.from_numpy(np.argsort(np.array(width_list)))
+ batch_num = self.rec_batch_num
+ preds_all = [None] * img_num
+ preds_neck_all = [None] * img_num
+ for beg_img_no in range(0, img_num, batch_num):
+ end_img_no = min(img_num, beg_img_no + batch_num)
+ norm_img_batch = []
+
+ imgC, imgH, imgW = self.rec_image_shape[:3]
+ max_wh_ratio = imgW / imgH
+ for ino in range(beg_img_no, end_img_no):
+ h, w = img_list[indices[ino]].shape[1:]
+ if h > w * 1.2:
+ img = img_list[indices[ino]]
+ img = torch.transpose(img, 1, 2).flip(dims=[1])
+ img_list[indices[ino]] = img
+ h, w = img.shape[1:]
+ # wh_ratio = w * 1.0 / h
+ # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio
+ for ino in range(beg_img_no, end_img_no):
+ norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
+ if self.use_fp16:
+ norm_img = norm_img.half()
+ norm_img = norm_img.unsqueeze(0)
+ norm_img_batch.append(norm_img)
+ norm_img_batch = torch.cat(norm_img_batch, dim=0)
+ if show_debug:
+ for i in range(len(norm_img_batch)):
+ _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
+ _img = (_img + 0.5) * 255
+ _img = _img[:, :, ::-1]
+ file_name = f"{indices[beg_img_no + i]}"
+ if os.path.exists(file_name + ".jpg"):
+ file_name += "_2" # ori image
+ cv2.imwrite(file_name + ".jpg", _img)
+ if self.is_onnx:
+ input_dict = {}
+ input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy()
+ outputs = self.predictor.run(None, input_dict)
+ preds = {}
+ preds["ctc"] = torch.from_numpy(outputs[0])
+ preds["ctc_neck"] = [torch.zeros(1)] * img_num
+ else:
+ preds = self.predictor(norm_img_batch.to(next(self.predictor.parameters()).device))
+ for rno in range(preds["ctc"].shape[0]):
+ preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno]
+ preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno]
+
+ return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)
+
+ def get_char_dict(self, character_dict_path):
+ character_str = []
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode("utf-8").strip("\n").strip("\r\n")
+ character_str.append(line)
+ dict_character = list(character_str)
+ dict_character = ["sos"] + dict_character + [" "] # eos is space
+ return dict_character
+
+ def get_text(self, order):
+ char_list = [self.chars[text_id] for text_id in order]
+ return "".join(char_list)
+
+ def decode(self, mat):
+ text_index = mat.detach().cpu().numpy().argmax(axis=1)
+ ignored_tokens = [0]
+ selection = np.ones(len(text_index), dtype=bool)
+ selection[1:] = text_index[1:] != text_index[:-1]
+ for ignored_token in ignored_tokens:
+ selection &= text_index != ignored_token
+ return text_index[selection], np.where(selection)[0]
+
+ def get_ctcloss(self, preds, gt_text, weight):
+ if not isinstance(weight, torch.Tensor):
+ weight = torch.tensor(weight).to(preds.device)
+ ctc_loss = torch.nn.CTCLoss(reduction="none")
+ log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC
+ targets = []
+ target_lengths = []
+ for t in gt_text:
+ targets += [self.char2id.get(i, len(self.chars) - 1) for i in t]
+ target_lengths += [len(t)]
+ targets = torch.tensor(targets).to(preds.device)
+ target_lengths = torch.tensor(target_lengths).to(preds.device)
+ input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(preds.device)
+ loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
+ loss = loss / input_lengths * weight
+ return loss
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+
+ @register_to_config
+ def __init__(
+ self,
+ device="cpu",
+ max_length=77,
+ freeze=True,
+ use_fp16=False,
+ variant: Optional[str] = None,
+ ):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
+ self.transformer = CLIPTextModel.from_pretrained(
+ "tolgacangoz/anytext",
+ subfolder="text_encoder",
+ torch_dtype=torch.float16 if use_fp16 else torch.float32,
+ variant="fp16" if use_fp16 else None,
+ )
+
+ if freeze:
+ self.freeze()
+
+ def embedding_forward(
+ self,
+ input_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ embedding_manager=None,
+ ):
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+ if embedding_manager is not None:
+ inputs_embeds = embedding_manager(input_ids, inputs_embeds)
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+ return embeddings
+
+ self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
+ self.transformer.text_model.embeddings
+ )
+
+ def encoder_forward(
+ self,
+ inputs_embeds,
+ attention_mask=None,
+ causal_attention_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ return hidden_states
+
+ self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
+
+ def text_encoder_forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ embedding_manager=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if input_ids is None:
+ raise ValueError("You have to specify either input_ids")
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ hidden_states = self.embeddings(
+ input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager
+ )
+ # CLIP's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = _create_4d_causal_attention_mask(
+ input_shape, hidden_states.dtype, device=hidden_states.device
+ )
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+ last_hidden_state = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+ return last_hidden_state
+
+ self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
+
+ def transformer_forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ embedding_manager=None,
+ ):
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ embedding_manager=embedding_manager,
+ )
+
+ self.transformer.forward = transformer_forward.__get__(self.transformer)
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text, **kwargs):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=False,
+ max_length=self.config.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="longest",
+ return_tensors="pt",
+ )
+ input_ids = batch_encoding["input_ids"]
+ tokens_list = self.split_chunks(input_ids)
+ z_list = []
+ for tokens in tokens_list:
+ tokens = tokens.to(self.device)
+ _z = self.transformer(input_ids=tokens, **kwargs)
+ z_list += [_z]
+ return torch.cat(z_list, dim=1)
+
+ def encode(self, text, **kwargs):
+ return self(text, **kwargs)
+
+ def split_chunks(self, input_ids, chunk_size=75):
+ tokens_list = []
+ bs, n = input_ids.shape
+ id_start = input_ids[:, 0].unsqueeze(1) # dim --> [bs, 1]
+ id_end = input_ids[:, -1].unsqueeze(1)
+ if n == 2: # empty caption
+ tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1))
+
+ trimmed_encoding = input_ids[:, 1:-1]
+ num_full_groups = (n - 2) // chunk_size
+
+ for i in range(num_full_groups):
+ group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size]
+ group_pad = torch.cat((id_start, group, id_end), dim=1)
+ tokens_list.append(group_pad)
+
+ remaining_columns = (n - 2) % chunk_size
+ if remaining_columns > 0:
+ remaining_group = trimmed_encoding[:, -remaining_columns:]
+ padding_columns = chunk_size - remaining_group.shape[1]
+ padding = id_end.expand(bs, padding_columns)
+ remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1)
+ tokens_list.append(remaining_group_pad)
+ return tokens_list
+
+
+class TextEmbeddingModule(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(self, font_path, use_fp16=False, device="cpu"):
+ super().__init__()
+ font = ImageFont.truetype(font_path, 60)
+
+ self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
+ self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
+ self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
+ args = {
+ "rec_image_shape": "3, 48, 320",
+ "rec_batch_num": 6,
+ "rec_char_dict_path": hf_hub_download(
+ repo_id="tolgacangoz/anytext",
+ filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
+ cache_dir=HF_MODULES_CACHE,
+ ),
+ "use_fp16": use_fp16,
+ }
+ self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
+
+ self.register_to_config(font=font)
+
+ @torch.no_grad()
+ def forward(
+ self,
+ prompt,
+ texts,
+ negative_prompt,
+ num_images_per_prompt,
+ mode,
+ draw_pos,
+ sort_priority="↕",
+ max_chars=77,
+ revise_pos=False,
+ h=512,
+ w=512,
+ ):
+ if prompt is None and texts is None:
+ raise ValueError("Prompt or texts must be provided!")
+ # preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
+ if draw_pos is None:
+ pos_imgs = np.zeros((w, h, 1))
+ if isinstance(draw_pos, PIL.Image.Image):
+ pos_imgs = np.array(draw_pos)[..., ::-1]
+ pos_imgs = 255 - pos_imgs
+ elif isinstance(draw_pos, str):
+ draw_pos = cv2.imread(draw_pos)[..., ::-1]
+ if draw_pos is None:
+ raise ValueError(f"Can't read draw_pos image from {draw_pos}!")
+ pos_imgs = 255 - draw_pos
+ elif isinstance(draw_pos, torch.Tensor):
+ pos_imgs = draw_pos.cpu().numpy()
+ else:
+ if not isinstance(draw_pos, np.ndarray):
+ raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}")
+ if mode == "edit":
+ pos_imgs = cv2.resize(pos_imgs, (w, h))
+ pos_imgs = pos_imgs[..., 0:1]
+ pos_imgs = cv2.convertScaleAbs(pos_imgs)
+ _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
+ # separate pos_imgs
+ pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)
+ if len(pos_imgs) == 0:
+ pos_imgs = [np.zeros((h, w, 1))]
+ n_lines = len(texts)
+ if len(pos_imgs) < n_lines:
+ if n_lines == 1 and texts[0] == " ":
+ pass # text-to-image without text
+ else:
+ raise ValueError(
+ f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!"
+ )
+ elif len(pos_imgs) > n_lines:
+ str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
+ logger.warning(str_warning)
+ # get pre_pos, poly_list, hint that needed for anytext
+ pre_pos = []
+ poly_list = []
+ for input_pos in pos_imgs:
+ if input_pos.mean() != 0:
+ input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos
+ poly, pos_img = self.find_polygon(input_pos)
+ pre_pos += [pos_img / 255.0]
+ poly_list += [poly]
+ else:
+ pre_pos += [np.zeros((h, w, 1))]
+ poly_list += [None]
+ np_hint = np.sum(pre_pos, axis=0).clip(0, 1)
+ # prepare info dict
+ text_info = {}
+ text_info["glyphs"] = []
+ text_info["gly_line"] = []
+ text_info["positions"] = []
+ text_info["n_lines"] = [len(texts)] * num_images_per_prompt
+ for i in range(len(texts)):
+ text = texts[i]
+ if len(text) > max_chars:
+ str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...'
+ logger.warning(str_warning)
+ text = text[:max_chars]
+ gly_scale = 2
+ if pre_pos[i].mean() != 0:
+ gly_line = self.draw_glyph(self.config.font, text)
+ glyphs = self.draw_glyph2(
+ self.config.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
+ )
+ if revise_pos:
+ resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))
+ new_pos = cv2.morphologyEx(
+ (resize_gly * 255).astype(np.uint8),
+ cv2.MORPH_CLOSE,
+ kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8),
+ iterations=1,
+ )
+ new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos
+ contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+ if len(contours) != 1:
+ str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..."
+ logger.warning(str_warning)
+ else:
+ rect = cv2.minAreaRect(contours[0])
+ poly = np.int0(cv2.boxPoints(rect))
+ pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0
+ else:
+ glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
+ gly_line = np.zeros((80, 512, 1))
+ pos = pre_pos[i]
+ text_info["glyphs"] += [self.arr2tensor(glyphs, num_images_per_prompt)]
+ text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)]
+ text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)]
+
+ self.embedding_manager.encode_text(text_info)
+ prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)
+
+ self.embedding_manager.encode_text(text_info)
+ negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode(
+ [negative_prompt or ""], embedding_manager=self.embedding_manager
+ )
+
+ return prompt_embeds, negative_prompt_embeds, text_info, np_hint
+
+ def arr2tensor(self, arr, bs):
+ arr = np.transpose(arr, (2, 0, 1))
+ _arr = torch.from_numpy(arr.copy()).float().cpu()
+ if self.config.use_fp16:
+ _arr = _arr.half()
+ _arr = torch.stack([_arr for _ in range(bs)], dim=0)
+ return _arr
+
+ def separate_pos_imgs(self, img, sort_priority, gap=102):
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)
+ components = []
+ for label in range(1, num_labels):
+ component = np.zeros_like(img)
+ component[labels == label] = 255
+ components.append((component, centroids[label]))
+ if sort_priority == "↕":
+ fir, sec = 1, 0 # top-down first
+ elif sort_priority == "↔":
+ fir, sec = 0, 1 # left-right first
+ else:
+ raise ValueError(f"Unknown sort_priority: {sort_priority}")
+ components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
+ sorted_components = [c[0] for c in components]
+ return sorted_components
+
+ def find_polygon(self, image, min_rect=False):
+ contours, hierarchy = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+ max_contour = max(contours, key=cv2.contourArea) # get contour with max area
+ if min_rect:
+ # get minimum enclosing rectangle
+ rect = cv2.minAreaRect(max_contour)
+ poly = np.int0(cv2.boxPoints(rect))
+ else:
+ # get approximate polygon
+ epsilon = 0.01 * cv2.arcLength(max_contour, True)
+ poly = cv2.approxPolyDP(max_contour, epsilon, True)
+ n, _, xy = poly.shape
+ poly = poly.reshape(n, xy)
+ cv2.drawContours(image, [poly], -1, 255, -1)
+ return poly, image
+
+ def draw_glyph(self, font, text):
+ g_size = 50
+ W, H = (512, 80)
+ new_font = font.font_variant(size=g_size)
+ img = Image.new(mode="1", size=(W, H), color=0)
+ draw = ImageDraw.Draw(img)
+ left, top, right, bottom = new_font.getbbox(text)
+ text_width = max(right - left, 5)
+ text_height = max(bottom - top, 5)
+ ratio = min(W * 0.9 / text_width, H * 0.9 / text_height)
+ new_font = font.font_variant(size=int(g_size * ratio))
+
+ left, top, right, bottom = new_font.getbbox(text)
+ text_width = right - left
+ text_height = bottom - top
+ x = (img.width - text_width) // 2
+ y = (img.height - text_height) // 2 - top // 2
+ draw.text((x, y), text, font=new_font, fill="white")
+ img = np.expand_dims(np.array(img), axis=2).astype(np.float64)
+ return img
+
+ def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True):
+ enlarge_polygon = polygon * scale
+ rect = cv2.minAreaRect(enlarge_polygon)
+ box = cv2.boxPoints(rect)
+ box = np.int0(box)
+ w, h = rect[1]
+ angle = rect[2]
+ if angle < -45:
+ angle += 90
+ angle = -angle
+ if w < h:
+ angle += 90
+
+ vert = False
+ if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng:
+ _w = max(box[:, 0]) - min(box[:, 0])
+ _h = max(box[:, 1]) - min(box[:, 1])
+ if _h >= _w:
+ vert = True
+ angle = 0
+
+ img = np.zeros((height * scale, width * scale, 3), np.uint8)
+ img = Image.fromarray(img)
+
+ # infer font size
+ image4ratio = Image.new("RGB", img.size, "white")
+ draw = ImageDraw.Draw(image4ratio)
+ _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
+ text_w = min(w, h) * (_tw / _th)
+ if text_w <= max(w, h):
+ # add space
+ if len(text) > 1 and not vert and add_space:
+ for i in range(1, 100):
+ text_space = self.insert_spaces(text, i)
+ _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)
+ if min(w, h) * (_tw2 / _th2) > max(w, h):
+ break
+ text = self.insert_spaces(text, i - 1)
+ font_size = min(w, h) * 0.80
+ else:
+ shrink = 0.75 if vert else 0.85
+ font_size = min(w, h) / (text_w / max(w, h)) * shrink
+ new_font = font.font_variant(size=int(font_size))
+
+ left, top, right, bottom = new_font.getbbox(text)
+ text_width = right - left
+ text_height = bottom - top
+
+ layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
+ draw = ImageDraw.Draw(layer)
+ if not vert:
+ draw.text(
+ (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top),
+ text,
+ font=new_font,
+ fill=(255, 255, 255, 255),
+ )
+ else:
+ x_s = min(box[:, 0]) + _w // 2 - text_height // 2
+ y_s = min(box[:, 1])
+ for c in text:
+ draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))
+ _, _t, _, _b = new_font.getbbox(c)
+ y_s += _b
+
+ rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))
+
+ x_offset = int((img.width - rotated_layer.width) / 2)
+ y_offset = int((img.height - rotated_layer.height) / 2)
+ img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
+ img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64)
+ return img
+
+ def insert_spaces(self, string, nSpace):
+ if nSpace == 0:
+ return string
+ new_string = ""
+ for char in string:
+ new_string += char + " " * nSpace
+ return new_string[:-nSpace]
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ vae,
+ device="cpu",
+ ):
+ super().__init__()
+
+ @torch.no_grad()
+ def forward(
+ self,
+ text_info,
+ mode,
+ draw_pos,
+ ori_image,
+ num_images_per_prompt,
+ np_hint,
+ h=512,
+ w=512,
+ ):
+ if mode == "generate":
+ edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
+ elif mode == "edit":
+ if draw_pos is None or ori_image is None:
+ raise ValueError("Reference image and position image are needed for text editing!")
+ if isinstance(ori_image, str):
+ ori_image = cv2.imread(ori_image)[..., ::-1]
+ if ori_image is None:
+ raise ValueError(f"Can't read ori_image image from {ori_image}!")
+ elif isinstance(ori_image, torch.Tensor):
+ ori_image = ori_image.cpu().numpy()
+ elif isinstance(ori_image, PIL.Image.Image):
+ ori_image = np.array(ori_image.convert("RGB"))
+ else:
+ if not isinstance(ori_image, np.ndarray):
+ raise ValueError(f"Unknown format of ori_image: {type(ori_image)}")
+ edit_image = ori_image.clip(1, 255) # for mask reason
+ edit_image = self.check_channels(edit_image)
+ edit_image = self.resize_image(
+ edit_image, max_length=768
+ ) # make w h multiple of 64, resize if w or h > max_length
+
+ # get masked_x
+ masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
+ masked_img = np.transpose(masked_img, (2, 0, 1))
+ device = next(self.config.vae.parameters()).device
+ dtype = next(self.config.vae.parameters()).dtype
+ masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
+ if dtype == torch.float16:
+ masked_img = masked_img.half()
+ masked_x = (
+ retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor
+ ).detach()
+ if dtype == torch.float16:
+ masked_x = masked_x.half()
+ text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)
+
+ glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True)
+ positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True)
+
+ return glyphs, positions, text_info
+
+ def check_channels(self, image):
+ channels = image.shape[2] if len(image.shape) == 3 else 1
+ if channels == 1:
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+ elif channels > 3:
+ image = image[:, :, :3]
+ return image
+
+ def resize_image(self, img, max_length=768):
+ height, width = img.shape[:2]
+ max_dimension = max(height, width)
+
+ if max_dimension > max_length:
+ scale_factor = max_length / max_dimension
+ new_width = int(round(width * scale_factor))
+ new_height = int(round(height * scale_factor))
+ new_size = (new_width, new_height)
+ img = cv2.resize(img, new_size)
+ height, width = img.shape[:2]
+ img = cv2.resize(img, (width - (width % 64), height - (height % 64)))
+ return img
+
+ def insert_spaces(self, string, nSpace):
+ if nSpace == 0:
+ return string
+ new_string = ""
+ for char in string:
+ new_string += char + " " * nSpace
+ return new_string[:-nSpace]
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class AnyTextPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ IPAdapterMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
+ additional conditioning.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ font_path: str = None,
+ text_embedding_module: Optional[TextEmbeddingModule] = None,
+ auxiliary_latent_module: Optional[AuxiliaryLatentModule] = None,
+ trust_remote_code: bool = False,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+ if font_path is None:
+ raise ValueError("font_path is required!")
+
+ text_embedding_module = TextEmbeddingModule(font_path=font_path, use_fp16=unet.dtype == torch.float16)
+ auxiliary_latent_module = AuxiliaryLatentModule(vae=vae)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ text_embedding_module=text_embedding_module,
+ auxiliary_latent_module=auxiliary_latent_module,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def modify_prompt(self, prompt):
+ prompt = prompt.replace("“", '"')
+ prompt = prompt.replace("”", '"')
+ p = '"(.*?)"'
+ strs = re.findall(p, prompt)
+ if len(strs) == 0:
+ strs = [" "]
+ else:
+ for s in strs:
+ prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
+ if self.is_chinese(prompt):
+ if self.trans_pipe is None:
+ return None, None
+ old_prompt = prompt
+ prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1]
+ print(f"Translate: {old_prompt} --> {prompt}")
+ return prompt, strs
+
+ def is_chinese(self, text):
+ text = checker._clean_text(text)
+ for char in text:
+ cp = ord(char)
+ if checker._is_chinese_char(cp):
+ return True
+ return False
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ image_embeds = []
+ if do_classifier_free_guidance:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if do_classifier_free_guidance:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ # image,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Check `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ print(controlnet_conditioning_scale)
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError(
+ "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
+ "The conditioning scale must be fixed across the batch."
+ )
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ mode: Optional[str] = "generate",
+ draw_pos: Optional[Union[str, torch.Tensor]] = None,
+ ori_image: Optional[Union[str, torch.Tensor]] = None,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single
+ ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple
+ ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ # image,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ prompt, texts = self.modify_prompt(prompt)
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ draw_pos = draw_pos.to(device=device) if isinstance(draw_pos, torch.Tensor) else draw_pos
+ prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module(
+ prompt,
+ texts,
+ negative_prompt,
+ num_images_per_prompt,
+ mode,
+ draw_pos,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 3.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 4. Prepare image
+ if isinstance(controlnet, ControlNetModel):
+ guided_hint = self.auxiliary_latent_module(
+ text_info=text_info,
+ mode=mode,
+ draw_pos=draw_pos,
+ ori_image=ori_image,
+ num_images_per_prompt=num_images_per_prompt,
+ np_hint=np_hint,
+ )
+ height, width = 512, 512
+ else:
+ assert False
+
+ # 5. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": image_embeds}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+ else None
+ )
+
+ # 7.2 Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ is_unet_compiled = is_compiled_module(self.unet)
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Relevant thread:
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ torch._inductor.cudagraph_mark_step_begin()
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input.to(self.controlnet.dtype),
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=guided_hint,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def to(self, *args, **kwargs):
+ super().to(*args, **kwargs)
+ self.text_embedding_module.to(*args, **kwargs)
+ self.auxiliary_latent_module.to(*args, **kwargs)
+ return self
diff --git a/examples/research_projects/anytext/anytext_controlnet.py b/examples/research_projects/anytext/anytext_controlnet.py
new file mode 100644
index 000000000000..5965ceed1370
--- /dev/null
+++ b/examples/research_projects/anytext/anytext_controlnet.py
@@ -0,0 +1,463 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054).
+# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie
+# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license
+#
+# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
+
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.controlnets.controlnet import (
+ ControlNetModel,
+ ControlNetOutput,
+)
+from diffusers.utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class AnyTextControlNetConditioningEmbedding(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
+ """
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ glyph_channels=1,
+ position_channels=1,
+ ):
+ super().__init__()
+
+ self.glyph_block = nn.Sequential(
+ nn.Conv2d(glyph_channels, 8, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(8, 8, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(8, 16, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(32, 32, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(96, 96, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ )
+
+ self.position_block = nn.Sequential(
+ nn.Conv2d(position_channels, 8, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(8, 8, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(8, 16, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(32, 32, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(32, 64, 3, padding=1, stride=2),
+ nn.SiLU(),
+ )
+
+ self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1)
+
+ def forward(self, glyphs, positions, text_info):
+ glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device))
+ position_embedding = self.position_block(positions.to(self.position_block[0].weight.device))
+ guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1))
+
+ return guided_hint
+
+
+class AnyTextControlNetModel(ControlNetModel):
+ """
+ A AnyTextControlNetModel model.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ use_linear_projection (`bool`, defaults to `False`):
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ num_class_embeds (`int`, *optional*, defaults to 0):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ TODO(Patrick) - unused parameter.
+ addition_embed_type_num_heads (`int`, defaults to 64):
+ The number of heads to use for the `TextTimeEmbedding` layer.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 1,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super().__init__(
+ in_channels,
+ conditioning_channels,
+ flip_sin_to_cos,
+ freq_shift,
+ down_block_types,
+ mid_block_type,
+ only_cross_attention,
+ block_out_channels,
+ layers_per_block,
+ downsample_padding,
+ mid_block_scale_factor,
+ act_fn,
+ norm_num_groups,
+ norm_eps,
+ cross_attention_dim,
+ transformer_layers_per_block,
+ encoder_hid_dim,
+ encoder_hid_dim_type,
+ attention_head_dim,
+ num_attention_heads,
+ use_linear_projection,
+ class_embed_type,
+ addition_embed_type,
+ addition_time_embed_dim,
+ num_class_embeds,
+ upcast_attention,
+ resnet_time_scale_shift,
+ projection_class_embeddings_input_dim,
+ controlnet_conditioning_channel_order,
+ conditioning_embedding_out_channels,
+ global_pool_conditions,
+ addition_embed_type_num_heads,
+ )
+
+ # control net conditioning embedding
+ self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ glyph_channels=conditioning_channels,
+ position_channels=conditioning_channels,
+ )
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
+ """
+ The [`~PromptDiffusionControlNetModel`] forward method.
+
+ Args:
+ sample (`torch.Tensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ #controlnet_cond (`torch.Tensor`):
+ # The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ guess_mode (`bool`, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
+ """
+ # check channel order
+ channel_order = self.config.controlnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ # elif channel_order == "bgr":
+ # controlnet_cond = torch.flip(controlnet_cond, dims=[1])
+ else:
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type is not None:
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond)
+ sample = sample + controlnet_cond
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # 5. Control net blocks
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+ scales = scales * conditioning_scale
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
+ else:
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if self.config.global_pool_conditions:
+ down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return ControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+# Copied from diffusers.models.controlnet.zero_module
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
diff --git a/examples/research_projects/anytext/ocr_recog/RNN.py b/examples/research_projects/anytext/ocr_recog/RNN.py
new file mode 100755
index 000000000000..aec796d987c0
--- /dev/null
+++ b/examples/research_projects/anytext/ocr_recog/RNN.py
@@ -0,0 +1,209 @@
+import torch
+from torch import nn
+
+from .RecSVTR import Block
+
+
+class Swish(nn.Module):
+ def __int__(self):
+ super(Swish, self).__int__()
+
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class Im2Im(nn.Module):
+ def __init__(self, in_channels, **kwargs):
+ super().__init__()
+ self.out_channels = in_channels
+
+ def forward(self, x):
+ return x
+
+
+class Im2Seq(nn.Module):
+ def __init__(self, in_channels, **kwargs):
+ super().__init__()
+ self.out_channels = in_channels
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # assert H == 1
+ x = x.reshape(B, C, H * W)
+ x = x.permute((0, 2, 1))
+ return x
+
+
+class EncoderWithRNN(nn.Module):
+ def __init__(self, in_channels, **kwargs):
+ super(EncoderWithRNN, self).__init__()
+ hidden_size = kwargs.get("hidden_size", 256)
+ self.out_channels = hidden_size * 2
+ self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True)
+
+ def forward(self, x):
+ self.lstm.flatten_parameters()
+ x, _ = self.lstm(x)
+ return x
+
+
+class SequenceEncoder(nn.Module):
+ def __init__(self, in_channels, encoder_type="rnn", **kwargs):
+ super(SequenceEncoder, self).__init__()
+ self.encoder_reshape = Im2Seq(in_channels)
+ self.out_channels = self.encoder_reshape.out_channels
+ self.encoder_type = encoder_type
+ if encoder_type == "reshape":
+ self.only_reshape = True
+ else:
+ support_encoder_dict = {"reshape": Im2Seq, "rnn": EncoderWithRNN, "svtr": EncoderWithSVTR}
+ assert encoder_type in support_encoder_dict, "{} must in {}".format(
+ encoder_type, support_encoder_dict.keys()
+ )
+
+ self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs)
+ self.out_channels = self.encoder.out_channels
+ self.only_reshape = False
+
+ def forward(self, x):
+ if self.encoder_type != "svtr":
+ x = self.encoder_reshape(x)
+ if not self.only_reshape:
+ x = self.encoder(x)
+ return x
+ else:
+ x = self.encoder(x)
+ x = self.encoder_reshape(x)
+ return x
+
+
+class ConvBNLayer(nn.Module):
+ def __init__(
+ self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
+ bias=bias_attr,
+ )
+ self.norm = nn.BatchNorm2d(out_channels)
+ self.act = Swish()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class EncoderWithSVTR(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ dims=64, # XS
+ depth=2,
+ hidden_dims=120,
+ use_guide=False,
+ num_heads=8,
+ qkv_bias=True,
+ mlp_ratio=2.0,
+ drop_rate=0.1,
+ attn_drop_rate=0.1,
+ drop_path=0.0,
+ qk_scale=None,
+ ):
+ super(EncoderWithSVTR, self).__init__()
+ self.depth = depth
+ self.use_guide = use_guide
+ self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act="swish")
+ self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act="swish")
+
+ self.svtr_block = nn.ModuleList(
+ [
+ Block(
+ dim=hidden_dims,
+ num_heads=num_heads,
+ mixer="Global",
+ HW=None,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer="swish",
+ attn_drop=attn_drop_rate,
+ drop_path=drop_path,
+ norm_layer="nn.LayerNorm",
+ epsilon=1e-05,
+ prenorm=False,
+ )
+ for i in range(depth)
+ ]
+ )
+ self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
+ self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
+ # last conv-nxn, the input is concat of input tensor and conv3 output tensor
+ self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act="swish")
+
+ self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
+ self.out_channels = dims
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ # weight initialization
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.ConvTranspose2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ # for use guide
+ if self.use_guide:
+ z = x.clone()
+ z.stop_gradient = True
+ else:
+ z = x
+ # for short cut
+ h = z
+ # reduce dim
+ z = self.conv1(z)
+ z = self.conv2(z)
+ # SVTR global block
+ B, C, H, W = z.shape
+ z = z.flatten(2).permute(0, 2, 1)
+
+ for blk in self.svtr_block:
+ z = blk(z)
+
+ z = self.norm(z)
+ # last stage
+ z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
+ z = self.conv3(z)
+ z = torch.cat((h, z), dim=1)
+ z = self.conv1x1(self.conv4(z))
+
+ return z
+
+
+if __name__ == "__main__":
+ svtrRNN = EncoderWithSVTR(56)
+ print(svtrRNN)
diff --git a/examples/research_projects/anytext/ocr_recog/RecCTCHead.py b/examples/research_projects/anytext/ocr_recog/RecCTCHead.py
new file mode 100755
index 000000000000..c066c6202b19
--- /dev/null
+++ b/examples/research_projects/anytext/ocr_recog/RecCTCHead.py
@@ -0,0 +1,45 @@
+from torch import nn
+
+
+class CTCHead(nn.Module):
+ def __init__(
+ self, in_channels, out_channels=6625, fc_decay=0.0004, mid_channels=None, return_feats=False, **kwargs
+ ):
+ super(CTCHead, self).__init__()
+ if mid_channels is None:
+ self.fc = nn.Linear(
+ in_channels,
+ out_channels,
+ bias=True,
+ )
+ else:
+ self.fc1 = nn.Linear(
+ in_channels,
+ mid_channels,
+ bias=True,
+ )
+ self.fc2 = nn.Linear(
+ mid_channels,
+ out_channels,
+ bias=True,
+ )
+
+ self.out_channels = out_channels
+ self.mid_channels = mid_channels
+ self.return_feats = return_feats
+
+ def forward(self, x, labels=None):
+ if self.mid_channels is None:
+ predicts = self.fc(x)
+ else:
+ x = self.fc1(x)
+ predicts = self.fc2(x)
+
+ if self.return_feats:
+ result = {}
+ result["ctc"] = predicts
+ result["ctc_neck"] = x
+ else:
+ result = predicts
+
+ return result
diff --git a/examples/research_projects/anytext/ocr_recog/RecModel.py b/examples/research_projects/anytext/ocr_recog/RecModel.py
new file mode 100755
index 000000000000..872ccade69e0
--- /dev/null
+++ b/examples/research_projects/anytext/ocr_recog/RecModel.py
@@ -0,0 +1,49 @@
+from torch import nn
+
+from .RecCTCHead import CTCHead
+from .RecMv1_enhance import MobileNetV1Enhance
+from .RNN import Im2Im, Im2Seq, SequenceEncoder
+
+
+backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance}
+neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im}
+head_dict = {"CTCHead": CTCHead}
+
+
+class RecModel(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ assert "in_channels" in config, "in_channels must in model config"
+ backbone_type = config["backbone"].pop("type")
+ assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
+ self.backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"])
+
+ neck_type = config["neck"].pop("type")
+ assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
+ self.neck = neck_dict[neck_type](self.backbone.out_channels, **config["neck"])
+
+ head_type = config["head"].pop("type")
+ assert head_type in head_dict, f"head.type must in {head_dict}"
+ self.head = head_dict[head_type](self.neck.out_channels, **config["head"])
+
+ self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"
+
+ def load_3rd_state_dict(self, _3rd_name, _state):
+ self.backbone.load_3rd_state_dict(_3rd_name, _state)
+ self.neck.load_3rd_state_dict(_3rd_name, _state)
+ self.head.load_3rd_state_dict(_3rd_name, _state)
+
+ def forward(self, x):
+ import torch
+
+ x = x.to(torch.float32)
+ x = self.backbone(x)
+ x = self.neck(x)
+ x = self.head(x)
+ return x
+
+ def encode(self, x):
+ x = self.backbone(x)
+ x = self.neck(x)
+ x = self.head.ctc_encoder(x)
+ return x
diff --git a/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py b/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py
new file mode 100644
index 000000000000..df41519b2713
--- /dev/null
+++ b/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py
@@ -0,0 +1,197 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .common import Activation
+
+
+class ConvBNLayer(nn.Module):
+ def __init__(
+ self, num_channels, filter_size, num_filters, stride, padding, channels=None, num_groups=1, act="hard_swish"
+ ):
+ super(ConvBNLayer, self).__init__()
+ self.act = act
+ self._conv = nn.Conv2d(
+ in_channels=num_channels,
+ out_channels=num_filters,
+ kernel_size=filter_size,
+ stride=stride,
+ padding=padding,
+ groups=num_groups,
+ bias=False,
+ )
+
+ self._batch_norm = nn.BatchNorm2d(
+ num_filters,
+ )
+ if self.act is not None:
+ self._act = Activation(act_type=act, inplace=True)
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+ if self.act is not None:
+ y = self._act(y)
+ return y
+
+
+class DepthwiseSeparable(nn.Module):
+ def __init__(
+ self, num_channels, num_filters1, num_filters2, num_groups, stride, scale, dw_size=3, padding=1, use_se=False
+ ):
+ super(DepthwiseSeparable, self).__init__()
+ self.use_se = use_se
+ self._depthwise_conv = ConvBNLayer(
+ num_channels=num_channels,
+ num_filters=int(num_filters1 * scale),
+ filter_size=dw_size,
+ stride=stride,
+ padding=padding,
+ num_groups=int(num_groups * scale),
+ )
+ if use_se:
+ self._se = SEModule(int(num_filters1 * scale))
+ self._pointwise_conv = ConvBNLayer(
+ num_channels=int(num_filters1 * scale),
+ filter_size=1,
+ num_filters=int(num_filters2 * scale),
+ stride=1,
+ padding=0,
+ )
+
+ def forward(self, inputs):
+ y = self._depthwise_conv(inputs)
+ if self.use_se:
+ y = self._se(y)
+ y = self._pointwise_conv(y)
+ return y
+
+
+class MobileNetV1Enhance(nn.Module):
+ def __init__(self, in_channels=3, scale=0.5, last_conv_stride=1, last_pool_type="max", **kwargs):
+ super().__init__()
+ self.scale = scale
+ self.block_list = []
+
+ self.conv1 = ConvBNLayer(
+ num_channels=in_channels, filter_size=3, channels=3, num_filters=int(32 * scale), stride=2, padding=1
+ )
+
+ conv2_1 = DepthwiseSeparable(
+ num_channels=int(32 * scale), num_filters1=32, num_filters2=64, num_groups=32, stride=1, scale=scale
+ )
+ self.block_list.append(conv2_1)
+
+ conv2_2 = DepthwiseSeparable(
+ num_channels=int(64 * scale), num_filters1=64, num_filters2=128, num_groups=64, stride=1, scale=scale
+ )
+ self.block_list.append(conv2_2)
+
+ conv3_1 = DepthwiseSeparable(
+ num_channels=int(128 * scale), num_filters1=128, num_filters2=128, num_groups=128, stride=1, scale=scale
+ )
+ self.block_list.append(conv3_1)
+
+ conv3_2 = DepthwiseSeparable(
+ num_channels=int(128 * scale),
+ num_filters1=128,
+ num_filters2=256,
+ num_groups=128,
+ stride=(2, 1),
+ scale=scale,
+ )
+ self.block_list.append(conv3_2)
+
+ conv4_1 = DepthwiseSeparable(
+ num_channels=int(256 * scale), num_filters1=256, num_filters2=256, num_groups=256, stride=1, scale=scale
+ )
+ self.block_list.append(conv4_1)
+
+ conv4_2 = DepthwiseSeparable(
+ num_channels=int(256 * scale),
+ num_filters1=256,
+ num_filters2=512,
+ num_groups=256,
+ stride=(2, 1),
+ scale=scale,
+ )
+ self.block_list.append(conv4_2)
+
+ for _ in range(5):
+ conv5 = DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=512,
+ num_groups=512,
+ stride=1,
+ dw_size=5,
+ padding=2,
+ scale=scale,
+ use_se=False,
+ )
+ self.block_list.append(conv5)
+
+ conv5_6 = DepthwiseSeparable(
+ num_channels=int(512 * scale),
+ num_filters1=512,
+ num_filters2=1024,
+ num_groups=512,
+ stride=(2, 1),
+ dw_size=5,
+ padding=2,
+ scale=scale,
+ use_se=True,
+ )
+ self.block_list.append(conv5_6)
+
+ conv6 = DepthwiseSeparable(
+ num_channels=int(1024 * scale),
+ num_filters1=1024,
+ num_filters2=1024,
+ num_groups=1024,
+ stride=last_conv_stride,
+ dw_size=5,
+ padding=2,
+ use_se=True,
+ scale=scale,
+ )
+ self.block_list.append(conv6)
+
+ self.block_list = nn.Sequential(*self.block_list)
+ if last_pool_type == "avg":
+ self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
+ else:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+ self.out_channels = int(1024 * scale)
+
+ def forward(self, inputs):
+ y = self.conv1(inputs)
+ y = self.block_list(y)
+ y = self.pool(y)
+ return y
+
+
+def hardsigmoid(x):
+ return F.relu6(x + 3.0, inplace=True) / 6.0
+
+
+class SEModule(nn.Module):
+ def __init__(self, channel, reduction=4):
+ super(SEModule, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = nn.Conv2d(
+ in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1, padding=0, bias=True
+ )
+ self.conv2 = nn.Conv2d(
+ in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1, padding=0, bias=True
+ )
+
+ def forward(self, inputs):
+ outputs = self.avg_pool(inputs)
+ outputs = self.conv1(outputs)
+ outputs = F.relu(outputs)
+ outputs = self.conv2(outputs)
+ outputs = hardsigmoid(outputs)
+ x = torch.mul(inputs, outputs)
+
+ return x
diff --git a/examples/research_projects/anytext/ocr_recog/RecSVTR.py b/examples/research_projects/anytext/ocr_recog/RecSVTR.py
new file mode 100644
index 000000000000..590a96995b26
--- /dev/null
+++ b/examples/research_projects/anytext/ocr_recog/RecSVTR.py
@@ -0,0 +1,570 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional
+from torch.nn.init import ones_, trunc_normal_, zeros_
+
+
+def drop_path(x, drop_prob=0.0, training=False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = torch.tensor(1 - drop_prob)
+ shape = (x.size()[0],) + (1,) * (x.ndim - 1)
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
+ random_tensor = torch.floor(random_tensor) # binarize
+ output = x.divide(keep_prob) * random_tensor
+ return output
+
+
+class Swish(nn.Module):
+ def __int__(self):
+ super(Swish, self).__int__()
+
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class ConvBNLayer(nn.Module):
+ def __init__(
+ self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
+ bias=bias_attr,
+ )
+ self.norm = nn.BatchNorm2d(out_channels)
+ self.act = act()
+
+ def forward(self, inputs):
+ out = self.conv(inputs)
+ out = self.norm(out)
+ out = self.act(out)
+ return out
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Identity(nn.Module):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, input):
+ return input
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ if isinstance(act_layer, str):
+ self.act = Swish()
+ else:
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class ConvMixer(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ HW=(8, 25),
+ local_k=(3, 3),
+ ):
+ super().__init__()
+ self.HW = HW
+ self.dim = dim
+ self.local_mixer = nn.Conv2d(
+ dim,
+ dim,
+ local_k,
+ 1,
+ (local_k[0] // 2, local_k[1] // 2),
+ groups=num_heads,
+ # weight_attr=ParamAttr(initializer=KaimingNormal())
+ )
+
+ def forward(self, x):
+ h = self.HW[0]
+ w = self.HW[1]
+ x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
+ x = self.local_mixer(x)
+ x = x.flatten(2).transpose([0, 2, 1])
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ mixer="Global",
+ HW=(8, 25),
+ local_k=(7, 11),
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.HW = HW
+ if HW is not None:
+ H = HW[0]
+ W = HW[1]
+ self.N = H * W
+ self.C = dim
+ if mixer == "Local" and HW is not None:
+ hk = local_k[0]
+ wk = local_k[1]
+ mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
+ for h in range(0, H):
+ for w in range(0, W):
+ mask[h * W + w, h : h + hk, w : w + wk] = 0.0
+ mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1)
+ mask_inf = torch.full([H * W, H * W], fill_value=float("-inf"))
+ mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
+ self.mask = mask[None, None, :]
+ # self.mask = mask.unsqueeze([0, 1])
+ self.mixer = mixer
+
+ def forward(self, x):
+ if self.HW is not None:
+ N = self.N
+ C = self.C
+ else:
+ _, N, C = x.shape
+ qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+
+ attn = q.matmul(k.permute((0, 1, 3, 2)))
+ if self.mixer == "Local":
+ attn += self.mask
+ attn = functional.softmax(attn, dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mixer="Global",
+ local_mixer=(7, 11),
+ HW=(8, 25),
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer="nn.LayerNorm",
+ epsilon=1e-6,
+ prenorm=True,
+ ):
+ super().__init__()
+ if isinstance(norm_layer, str):
+ self.norm1 = eval(norm_layer)(dim, eps=epsilon)
+ else:
+ self.norm1 = norm_layer(dim)
+ if mixer == "Global" or mixer == "Local":
+ self.mixer = Attention(
+ dim,
+ num_heads=num_heads,
+ mixer=mixer,
+ HW=HW,
+ local_k=local_mixer,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ elif mixer == "Conv":
+ self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
+ else:
+ raise TypeError("The mixer must be one of [Global, Local, Conv]")
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
+ if isinstance(norm_layer, str):
+ self.norm2 = eval(norm_layer)(dim, eps=epsilon)
+ else:
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp_ratio = mlp_ratio
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.prenorm = prenorm
+
+ def forward(self, x):
+ if self.prenorm:
+ x = self.norm1(x + self.drop_path(self.mixer(x)))
+ x = self.norm2(x + self.drop_path(self.mlp(x)))
+ else:
+ x = x + self.drop_path(self.mixer(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):
+ super().__init__()
+ num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
+ self.img_size = img_size
+ self.num_patches = num_patches
+ self.embed_dim = embed_dim
+ self.norm = None
+ if sub_num == 2:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=False,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=False,
+ ),
+ )
+ if sub_num == 3:
+ self.proj = nn.Sequential(
+ ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=embed_dim // 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=False,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 4,
+ out_channels=embed_dim // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=False,
+ ),
+ ConvBNLayer(
+ in_channels=embed_dim // 2,
+ out_channels=embed_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ act=nn.GELU,
+ bias_attr=False,
+ ),
+ )
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ assert (
+ H == self.img_size[0] and W == self.img_size[1]
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).permute(0, 2, 1)
+ return x
+
+
+class SubSample(nn.Module):
+ def __init__(self, in_channels, out_channels, types="Pool", stride=(2, 1), sub_norm="nn.LayerNorm", act=None):
+ super().__init__()
+ self.types = types
+ if types == "Pool":
+ self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
+ self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
+ self.proj = nn.Linear(in_channels, out_channels)
+ else:
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ # weight_attr=ParamAttr(initializer=KaimingNormal())
+ )
+ self.norm = eval(sub_norm)(out_channels)
+ if act is not None:
+ self.act = act()
+ else:
+ self.act = None
+
+ def forward(self, x):
+ if self.types == "Pool":
+ x1 = self.avgpool(x)
+ x2 = self.maxpool(x)
+ x = (x1 + x2) * 0.5
+ out = self.proj(x.flatten(2).permute((0, 2, 1)))
+ else:
+ x = self.conv(x)
+ out = x.flatten(2).permute((0, 2, 1))
+ out = self.norm(out)
+ if self.act is not None:
+ out = self.act(out)
+
+ return out
+
+
+class SVTRNet(nn.Module):
+ def __init__(
+ self,
+ img_size=[48, 100],
+ in_channels=3,
+ embed_dim=[64, 128, 256],
+ depth=[3, 6, 3],
+ num_heads=[2, 4, 8],
+ mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv
+ local_mixer=[[7, 11], [7, 11], [7, 11]],
+ patch_merging="Conv", # Conv, Pool, None
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ last_drop=0.1,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer="nn.LayerNorm",
+ sub_norm="nn.LayerNorm",
+ epsilon=1e-6,
+ out_channels=192,
+ out_char_num=25,
+ block_unit="Block",
+ act="nn.GELU",
+ last_stage=True,
+ sub_num=2,
+ prenorm=True,
+ use_lenhead=False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.img_size = img_size
+ self.embed_dim = embed_dim
+ self.out_channels = out_channels
+ self.prenorm = prenorm
+ patch_merging = None if patch_merging != "Conv" and patch_merging != "Pool" else patch_merging
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, in_channels=in_channels, embed_dim=embed_dim[0], sub_num=sub_num
+ )
+ num_patches = self.patch_embed.num_patches
+ self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
+ # self.pos_embed = self.create_parameter(
+ # shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
+
+ # self.add_parameter("pos_embed", self.pos_embed)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ Block_unit = eval(block_unit)
+
+ dpr = np.linspace(0, drop_path_rate, sum(depth))
+ self.blocks1 = nn.ModuleList(
+ [
+ Block_unit(
+ dim=embed_dim[0],
+ num_heads=num_heads[0],
+ mixer=mixer[0 : depth[0]][i],
+ HW=self.HW,
+ local_mixer=local_mixer[0],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[0 : depth[0]][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm,
+ )
+ for i in range(depth[0])
+ ]
+ )
+ if patch_merging is not None:
+ self.sub_sample1 = SubSample(
+ embed_dim[0], embed_dim[1], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
+ )
+ HW = [self.HW[0] // 2, self.HW[1]]
+ else:
+ HW = self.HW
+ self.patch_merging = patch_merging
+ self.blocks2 = nn.ModuleList(
+ [
+ Block_unit(
+ dim=embed_dim[1],
+ num_heads=num_heads[1],
+ mixer=mixer[depth[0] : depth[0] + depth[1]][i],
+ HW=HW,
+ local_mixer=local_mixer[1],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm,
+ )
+ for i in range(depth[1])
+ ]
+ )
+ if patch_merging is not None:
+ self.sub_sample2 = SubSample(
+ embed_dim[1], embed_dim[2], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
+ )
+ HW = [self.HW[0] // 4, self.HW[1]]
+ else:
+ HW = self.HW
+ self.blocks3 = nn.ModuleList(
+ [
+ Block_unit(
+ dim=embed_dim[2],
+ num_heads=num_heads[2],
+ mixer=mixer[depth[0] + depth[1] :][i],
+ HW=HW,
+ local_mixer=local_mixer[2],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ act_layer=eval(act),
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[depth[0] + depth[1] :][i],
+ norm_layer=norm_layer,
+ epsilon=epsilon,
+ prenorm=prenorm,
+ )
+ for i in range(depth[2])
+ ]
+ )
+ self.last_stage = last_stage
+ if last_stage:
+ self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
+ self.last_conv = nn.Conv2d(
+ in_channels=embed_dim[2],
+ out_channels=self.out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.hardswish = nn.Hardswish()
+ self.dropout = nn.Dropout(p=last_drop)
+ if not prenorm:
+ self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
+ self.use_lenhead = use_lenhead
+ if use_lenhead:
+ self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
+ self.hardswish_len = nn.Hardswish()
+ self.dropout_len = nn.Dropout(p=last_drop)
+
+ trunc_normal_(self.pos_embed, std=0.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+ for blk in self.blocks1:
+ x = blk(x)
+ if self.patch_merging is not None:
+ x = self.sub_sample1(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
+ for blk in self.blocks2:
+ x = blk(x)
+ if self.patch_merging is not None:
+ x = self.sub_sample2(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
+ for blk in self.blocks3:
+ x = blk(x)
+ if not self.prenorm:
+ x = self.norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ if self.use_lenhead:
+ len_x = self.len_conv(x.mean(1))
+ len_x = self.dropout_len(self.hardswish_len(len_x))
+ if self.last_stage:
+ if self.patch_merging is not None:
+ h = self.HW[0] // 4
+ else:
+ h = self.HW[0]
+ x = self.avg_pool(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]]))
+ x = self.last_conv(x)
+ x = self.hardswish(x)
+ x = self.dropout(x)
+ if self.use_lenhead:
+ return x, len_x
+ return x
+
+
+if __name__ == "__main__":
+ a = torch.rand(1, 3, 48, 100)
+ svtr = SVTRNet()
+
+ out = svtr(a)
+ print(svtr)
+ print(out.size())
diff --git a/examples/research_projects/anytext/ocr_recog/common.py b/examples/research_projects/anytext/ocr_recog/common.py
new file mode 100644
index 000000000000..207a95b17d0e
--- /dev/null
+++ b/examples/research_projects/anytext/ocr_recog/common.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Hswish(nn.Module):
+ def __init__(self, inplace=True):
+ super(Hswish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
+
+
+# out = max(0, min(1, slop*x+offset))
+# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
+class Hsigmoid(nn.Module):
+ def __init__(self, inplace=True):
+ super(Hsigmoid, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ # torch: F.relu6(x + 3., inplace=self.inplace) / 6.
+ # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
+ return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0
+
+
+class GELU(nn.Module):
+ def __init__(self, inplace=True):
+ super(GELU, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return torch.nn.functional.gelu(x)
+
+
+class Swish(nn.Module):
+ def __init__(self, inplace=True):
+ super(Swish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ if self.inplace:
+ x.mul_(torch.sigmoid(x))
+ return x
+ else:
+ return x * torch.sigmoid(x)
+
+
+class Activation(nn.Module):
+ def __init__(self, act_type, inplace=True):
+ super(Activation, self).__init__()
+ act_type = act_type.lower()
+ if act_type == "relu":
+ self.act = nn.ReLU(inplace=inplace)
+ elif act_type == "relu6":
+ self.act = nn.ReLU6(inplace=inplace)
+ elif act_type == "sigmoid":
+ raise NotImplementedError
+ elif act_type == "hard_sigmoid":
+ self.act = Hsigmoid(inplace)
+ elif act_type == "hard_swish":
+ self.act = Hswish(inplace=inplace)
+ elif act_type == "leakyrelu":
+ self.act = nn.LeakyReLU(inplace=inplace)
+ elif act_type == "gelu":
+ self.act = GELU(inplace=inplace)
+ elif act_type == "swish":
+ self.act = Swish(inplace=inplace)
+ else:
+ raise NotImplementedError
+
+ def forward(self, inputs):
+ return self.act(inputs)
diff --git a/examples/research_projects/anytext/ocr_recog/en_dict.txt b/examples/research_projects/anytext/ocr_recog/en_dict.txt
new file mode 100644
index 000000000000..7677d31b9d3f
--- /dev/null
+++ b/examples/research_projects/anytext/ocr_recog/en_dict.txt
@@ -0,0 +1,95 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+:
+;
+<
+=
+>
+?
+@
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+[
+\
+]
+^
+_
+`
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+{
+|
+}
+~
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+
diff --git a/examples/research_projects/autoencoderkl/README.md b/examples/research_projects/autoencoderkl/README.md
new file mode 100644
index 000000000000..c62018312da5
--- /dev/null
+++ b/examples/research_projects/autoencoderkl/README.md
@@ -0,0 +1,59 @@
+# AutoencoderKL training example
+
+## Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd in the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+## Training on CIFAR10
+
+Please replace the validation image with your own image.
+
+```bash
+accelerate launch train_autoencoderkl.py \
+ --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
+ --dataset_name=cifar10 \
+ --image_column=img \
+ --validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
+ --num_train_epochs 100 \
+ --gradient_accumulation_steps 2 \
+ --learning_rate 4.5e-6 \
+ --lr_scheduler cosine \
+ --report_to wandb \
+```
+
+## Training on ImageNet
+
+```bash
+accelerate launch train_autoencoderkl.py \
+ --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \
+ --num_train_epochs 100 \
+ --gradient_accumulation_steps 2 \
+ --learning_rate 4.5e-6 \
+ --lr_scheduler cosine \
+ --report_to wandb \
+ --mixed_precision bf16 \
+ --train_data_dir /path/to/ImageNet/train \
+ --validation_image ./image.png \
+ --decoder_only
+```
diff --git a/examples/research_projects/autoencoderkl/requirements.txt b/examples/research_projects/autoencoderkl/requirements.txt
new file mode 100644
index 000000000000..fe501252b46a
--- /dev/null
+++ b/examples/research_projects/autoencoderkl/requirements.txt
@@ -0,0 +1,15 @@
+accelerate>=0.16.0
+bitsandbytes
+datasets
+huggingface_hub
+lpips
+numpy
+packaging
+Pillow
+taming_transformers
+torch
+torchvision
+tqdm
+transformers
+wandb
+xformers
diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.py b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
new file mode 100644
index 000000000000..31cf8414ac10
--- /dev/null
+++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
@@ -0,0 +1,1061 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import contextlib
+import gc
+import logging
+import math
+import os
+import shutil
+from pathlib import Path
+
+import accelerate
+import lpips
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from taming.modules.losses.vqperceptual import NLayerDiscriminator, hinge_d_loss, vanilla_d_loss, weights_init
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import AutoencoderKL
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.33.0.dev0")
+
+logger = get_logger(__name__)
+
+
+@torch.no_grad()
+def log_validation(vae, args, accelerator, weight_dtype, step, is_final_validation=False):
+ logger.info("Running validation... ")
+
+ if not is_final_validation:
+ vae = accelerator.unwrap_model(vae)
+ else:
+ vae = AutoencoderKL.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
+
+ images = []
+ inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
+
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ for i, validation_image in enumerate(args.validation_image):
+ validation_image = Image.open(validation_image).convert("RGB")
+ targets = image_transforms(validation_image).to(accelerator.device, weight_dtype)
+ targets = targets.unsqueeze(0)
+
+ with inference_ctx:
+ reconstructions = vae(targets).sample
+
+ images.append(torch.cat([targets.cpu(), reconstructions.cpu()], axis=0))
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(f"{tracker_key}: Original (left), Reconstruction (right)", np_images, step)
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ f"{tracker_key}: Original (left), Reconstruction (right)": [
+ wandb.Image(torchvision.utils.make_grid(image)) for _, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None):
+ img_str = ""
+ if images is not None:
+ img_str = "You can find some example images below.\n\n"
+ make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, "images.png"))
+ img_str += "\n"
+
+ model_description = f"""
+# autoencoderkl-{repo_id}
+
+These are autoencoderkl weights trained on {base_model} with new type of conditioning.
+{img_str}
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion",
+ "stable-diffusion-diffusers",
+ "image-to-image",
+ "diffusers",
+ "autoencoderkl",
+ "diffusers-training",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a AutoencoderKL training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--model_config_name_or_path",
+ type=str,
+ default=None,
+ help="The config of the VAE model to train, leave as None to use standard VAE model configuration.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="autoencoderkl-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=4.5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--disc_learning_rate",
+ type=float,
+ default=4.5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--disc_lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help="A set of paths to the image be evaluated every `--validation_steps` and logged to `--report_to`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="train_autoencoderkl",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+ parser.add_argument(
+ "--rec_loss",
+ type=str,
+ default="l2",
+ help="The loss function for VAE reconstruction loss.",
+ )
+ parser.add_argument(
+ "--kl_scale",
+ type=float,
+ default=1e-6,
+ help="Scaling factor for the Kullback-Leibler divergence penalty term.",
+ )
+ parser.add_argument(
+ "--perceptual_scale",
+ type=float,
+ default=0.5,
+ help="Scaling factor for the LPIPS metric",
+ )
+ parser.add_argument(
+ "--disc_start",
+ type=int,
+ default=50001,
+ help="Start for the discriminator",
+ )
+ parser.add_argument(
+ "--disc_factor",
+ type=float,
+ default=1.0,
+ help="Scaling factor for the discriminator",
+ )
+ parser.add_argument(
+ "--disc_scale",
+ type=float,
+ default=1.0,
+ help="Scaling factor for the discriminator",
+ )
+ parser.add_argument(
+ "--disc_loss",
+ type=str,
+ default="hinge",
+ help="Loss function for the discriminator",
+ )
+ parser.add_argument(
+ "--decoder_only",
+ action="store_true",
+ help="Only train the VAE decoder.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.pretrained_model_name_or_path is not None and args.model_config_name_or_path is not None:
+ raise ValueError("Cannot specify both `--pretrained_model_name_or_path` and `--model_config_name_or_path`")
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the diffusion model."
+ )
+
+ return args
+
+
+def make_train_dataset(args, accelerator):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ images = [image_transforms(image) for image in images]
+
+ examples["pixel_values"] = images
+
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ return train_dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ return {"pixel_values": pixel_values}
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load AutoencoderKL
+ if args.pretrained_model_name_or_path is None and args.model_config_name_or_path is None:
+ config = AutoencoderKL.load_config("stabilityai/sd-vae-ft-mse")
+ vae = AutoencoderKL.from_config(config)
+ elif args.pretrained_model_name_or_path is not None:
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, revision=args.revision)
+ else:
+ config = AutoencoderKL.load_config(args.model_config_name_or_path)
+ vae = AutoencoderKL.from_config(config)
+ if args.use_ema:
+ ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
+ perceptual_loss = lpips.LPIPS(net="vgg").eval()
+ discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
+ discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
+
+ # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ sub_dir = "autoencoderkl_ema"
+ ema_vae.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ if isinstance(model, AutoencoderKL):
+ sub_dir = "autoencoderkl"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+ else:
+ sub_dir = "discriminator"
+ os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True)
+ torch.save(model.state_dict(), os.path.join(output_dir, sub_dir, "pytorch_model.bin"))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ if args.use_ema:
+ sub_dir = "autoencoderkl_ema"
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, sub_dir), AutoencoderKL)
+ ema_vae.load_state_dict(load_model.state_dict())
+ ema_vae.to(accelerator.device)
+ del load_model
+
+ # pop models so that they are not loaded again
+ model = models.pop()
+ load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).load_state_dict(
+ os.path.join(input_dir, "discriminator", "pytorch_model.bin")
+ )
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ model = models.pop()
+ load_model = AutoencoderKL.from_pretrained(input_dir, subfolder="autoencoderkl")
+ model.register_to_config(**load_model.config)
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ vae.requires_grad_(True)
+ if args.decoder_only:
+ vae.encoder.requires_grad_(False)
+ if getattr(vae, "quant_conv", None):
+ vae.quant_conv.requires_grad_(False)
+ vae.train()
+ discriminator.requires_grad_(True)
+ discriminator.train()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ vae.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ vae.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if unwrap_model(vae).dtype != torch.float32:
+ raise ValueError(f"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ params_to_optimize = filter(lambda p: p.requires_grad, vae.parameters())
+ disc_params_to_optimize = filter(lambda p: p.requires_grad, discriminator.parameters())
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+ disc_optimizer = optimizer_class(
+ disc_params_to_optimize,
+ lr=args.disc_learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ train_dataset = make_train_dataset(args, accelerator)
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+ disc_lr_scheduler = get_scheduler(
+ args.disc_lr_scheduler,
+ optimizer=disc_optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ (
+ vae,
+ discriminator,
+ optimizer,
+ disc_optimizer,
+ train_dataloader,
+ lr_scheduler,
+ disc_lr_scheduler,
+ ) = accelerator.prepare(
+ vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move VAE, perceptual loss and discriminator to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=weight_dtype)
+ perceptual_loss.to(accelerator.device, dtype=weight_dtype)
+ discriminator.to(accelerator.device, dtype=weight_dtype)
+ if args.use_ema:
+ ema_vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ vae.train()
+ discriminator.train()
+ for step, batch in enumerate(train_dataloader):
+ # Convert images to latent space and reconstruct from them
+ targets = batch["pixel_values"].to(dtype=weight_dtype)
+ posterior = accelerator.unwrap_model(vae).encode(targets).latent_dist
+ latents = posterior.sample()
+ reconstructions = accelerator.unwrap_model(vae).decode(latents).sample
+
+ if (step // args.gradient_accumulation_steps) % 2 == 0 or global_step < args.disc_start:
+ with accelerator.accumulate(vae):
+ # reconstruction loss. Pixel level differences between input vs output
+ if args.rec_loss == "l2":
+ rec_loss = F.mse_loss(reconstructions.float(), targets.float(), reduction="none")
+ elif args.rec_loss == "l1":
+ rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction="none")
+ else:
+ raise ValueError(f"Invalid reconstruction loss type: {args.rec_loss}")
+ # perceptual loss. The high level feature mean squared error loss
+ with torch.no_grad():
+ p_loss = perceptual_loss(reconstructions, targets)
+
+ rec_loss = rec_loss + args.perceptual_scale * p_loss
+ nll_loss = rec_loss
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+
+ kl_loss = posterior.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+
+ logits_fake = discriminator(reconstructions)
+ g_loss = -torch.mean(logits_fake)
+ last_layer = accelerator.unwrap_model(vae).decoder.conv_out.weight
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ disc_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ disc_weight = torch.clamp(disc_weight, 0.0, 1e4).detach()
+ disc_weight = disc_weight * args.disc_scale
+ disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
+
+ loss = nll_loss + args.kl_scale * kl_loss + disc_weight * disc_factor * g_loss
+
+ logs = {
+ "loss": loss.detach().mean().item(),
+ "nll_loss": nll_loss.detach().mean().item(),
+ "rec_loss": rec_loss.detach().mean().item(),
+ "p_loss": p_loss.detach().mean().item(),
+ "kl_loss": kl_loss.detach().mean().item(),
+ "disc_weight": disc_weight.detach().mean().item(),
+ "disc_factor": disc_factor,
+ "g_loss": g_loss.detach().mean().item(),
+ "lr": lr_scheduler.get_last_lr()[0],
+ }
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = vae.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+ else:
+ with accelerator.accumulate(discriminator):
+ logits_real = discriminator(targets)
+ logits_fake = discriminator(reconstructions)
+ disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
+ disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
+ d_loss = disc_factor * disc_loss(logits_real, logits_fake)
+ logs = {
+ "disc_loss": d_loss.detach().mean().item(),
+ "logits_real": logits_real.detach().mean().item(),
+ "logits_fake": logits_fake.detach().mean().item(),
+ "disc_lr": disc_lr_scheduler.get_last_lr()[0],
+ }
+ accelerator.backward(d_loss)
+ if accelerator.sync_gradients:
+ params_to_clip = discriminator.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ disc_optimizer.step()
+ disc_lr_scheduler.step()
+ disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ if args.use_ema:
+ ema_vae.step(vae.parameters())
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step == 1 or global_step % args.validation_steps == 0:
+ if args.use_ema:
+ ema_vae.store(vae.parameters())
+ ema_vae.copy_to(vae.parameters())
+ image_logs = log_validation(
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ ema_vae.restore(vae.parameters())
+
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ vae = accelerator.unwrap_model(vae)
+ discriminator = accelerator.unwrap_model(discriminator)
+ if args.use_ema:
+ ema_vae.copy_to(vae.parameters())
+ vae.save_pretrained(args.output_dir)
+ torch.save(discriminator.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin"))
+ # Run a final round of validation.
+ image_logs = None
+ image_logs = log_validation(
+ vae=vae,
+ args=args,
+ accelerator=accelerator,
+ weight_dtype=weight_dtype,
+ step=global_step,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/research_projects/consistency_training/train_cm_ct_unconditional.py b/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
index eccc539f230c..2bea064cdb72 100644
--- a/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
+++ b/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py
index 88a5d93d8edf..829b0031156e 100644
--- a/examples/research_projects/controlnet/train_controlnet_webdataset.py
+++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -381,9 +381,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
- formatted_images = []
-
- formatted_images.append(np.asarray(validation_image))
+ formatted_images = [np.asarray(validation_image)]
for image in images:
formatted_images.append(np.asarray(image))
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
index cdc096190f08..ed245e9cef7d 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
index cd1ef265d23e..66a7a3652947 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/flux_lora_quantization/README.md b/examples/research_projects/flux_lora_quantization/README.md
new file mode 100644
index 000000000000..51005b640221
--- /dev/null
+++ b/examples/research_projects/flux_lora_quantization/README.md
@@ -0,0 +1,167 @@
+## LoRA fine-tuning Flux.1 Dev with quantization
+
+> [!NOTE]
+> This example is educational in nature and fixes some arguments to keep things simple. It should act as a reference to build things further.
+
+This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow:
+
+* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.
+ * Even though optional, we load the T5-xxl in NF4 to further reduce the memory foot-print.
+* `train_dreambooth_lora_flux_miniature.py` takes care of training:
+ * Since we already precomputed the text embeddings, we don't load the text encoders.
+ * We load the VAE and use it to precompute the image latents and we then delete it.
+ * Load the Flux transformer, quantize it with the [NF4 datatype](https://arxiv.org/abs/2305.14314) through `bitsandbytes`, prepare it for 4bit training.
+ * Add LoRA adapter layers to it and then ensure they are kept in FP32 precision.
+ * Train!
+
+To run training in a memory-optimized manner, we additionally use:
+
+* 8Bit Adam
+* Gradient checkpointing
+
+We have tested the scripts on a 24GB 4090. It works on a free-tier Colab Notebook, too, but it's extremely slow.
+
+## Training
+
+Ensure you have installed the required libraries:
+
+```bash
+pip install -U transformers accelerate bitsandbytes peft datasets
+pip install git+https://github.com/huggingface/diffusers -U
+```
+
+Now, compute the text embeddings:
+
+```bash
+python compute_embeddings.py
+```
+
+It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model:
+
+```bash
+huggingface-cli
+```
+
+Then launch:
+
+```bash
+accelerate launch --config_file=accelerate.yaml \
+ train_dreambooth_lora_flux_miniature.py \
+ --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
+ --data_df_path="embeddings.parquet" \
+ --output_dir="yarn_art_lora_flux_nf4" \
+ --mixed_precision="fp16" \
+ --use_8bit_adam \
+ --weighting_scheme="none" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --repeats=1 \
+ --learning_rate=1e-4 \
+ --guidance_scale=1 \
+ --report_to="wandb" \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --cache_latents \
+ --rank=4 \
+ --max_train_steps=700 \
+ --seed="0"
+```
+
+We can direcly pass a quantized checkpoint path, too:
+
+```diff
++ --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg"
+```
+
+Depending on the machine, training time will vary but for our case, it was 1.5 hours. It maybe possible to speed this up by using `torch.bfloat16`.
+
+We support training with the DeepSpeed Zero2 optimizer, too. To use it, first install DeepSpeed:
+
+```bash
+pip install -Uq deepspeed
+```
+
+And then launch:
+
+```bash
+accelerate launch --config_file=ds2.yaml \
+ train_dreambooth_lora_flux_miniature.py \
+ --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
+ --data_df_path="embeddings.parquet" \
+ --output_dir="yarn_art_lora_flux_nf4" \
+ --mixed_precision="no" \
+ --use_8bit_adam \
+ --weighting_scheme="none" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --repeats=1 \
+ --learning_rate=1e-4 \
+ --guidance_scale=1 \
+ --report_to="wandb" \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --cache_latents \
+ --rank=4 \
+ --max_train_steps=700 \
+ --seed="0"
+```
+
+## Inference
+
+When loading the LoRA params (that were obtained on a quantized base model) and merging them into the base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging into 4bit quantized models can lead to some rounding errors. Below, we provide an end-to-end example:
+
+1. First, load the original model and merge the LoRA params into it:
+
+```py
+from diffusers import FluxPipeline
+import torch
+
+ckpt_id = "black-forest-labs/FLUX.1-dev"
+pipeline = FluxPipeline.from_pretrained(
+ ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
+)
+pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors")
+pipeline.fuse_lora()
+pipeline.unload_lora_weights()
+
+pipeline.transformer.save_pretrained("fused_transformer")
+```
+
+2. Quantize the model and run inference
+
+```py
+from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig
+import torch
+
+ckpt_id = "black-forest-labs/FLUX.1-dev"
+bnb_4bit_compute_dtype = torch.float16
+nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
+)
+transformer = FluxTransformer2DModel.from_pretrained(
+ "fused_transformer",
+ quantization_config=nf4_config,
+ torch_dtype=bnb_4bit_compute_dtype,
+)
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype
+)
+pipeline.enable_model_cpu_offload()
+
+image = pipeline(
+ "a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768
+).images[0]
+image.save("yarn_merged.png")
+```
+
+| Dequantize, merge, quantize | Merging directly into quantized model |
+|-------|-------|
+|  |  |
+
+As we can notice the first column result follows the style more closely.
diff --git a/examples/research_projects/flux_lora_quantization/accelerate.yaml b/examples/research_projects/flux_lora_quantization/accelerate.yaml
new file mode 100644
index 000000000000..309e13cc140a
--- /dev/null
+++ b/examples/research_projects/flux_lora_quantization/accelerate.yaml
@@ -0,0 +1,17 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: NO
+downcast_bf16: 'no'
+enable_cpu_affinity: true
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/examples/research_projects/flux_lora_quantization/compute_embeddings.py b/examples/research_projects/flux_lora_quantization/compute_embeddings.py
new file mode 100644
index 000000000000..1878b70f1372
--- /dev/null
+++ b/examples/research_projects/flux_lora_quantization/compute_embeddings.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+import pandas as pd
+import torch
+from datasets import load_dataset
+from huggingface_hub.utils import insecure_hashlib
+from tqdm.auto import tqdm
+from transformers import T5EncoderModel
+
+from diffusers import FluxPipeline
+
+
+MAX_SEQ_LENGTH = 77
+OUTPUT_PATH = "embeddings.parquet"
+
+
+def generate_image_hash(image):
+ return insecure_hashlib.sha256(image.tobytes()).hexdigest()
+
+
+def load_flux_dev_pipeline():
+ id = "black-forest-labs/FLUX.1-dev"
+ text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto")
+ pipeline = FluxPipeline.from_pretrained(
+ id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced"
+ )
+ return pipeline
+
+
+@torch.no_grad()
+def compute_embeddings(pipeline, prompts, max_sequence_length):
+ all_prompt_embeds = []
+ all_pooled_prompt_embeds = []
+ all_text_ids = []
+ for prompt in tqdm(prompts, desc="Encoding prompts."):
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length)
+ all_prompt_embeds.append(prompt_embeds)
+ all_pooled_prompt_embeds.append(pooled_prompt_embeds)
+ all_text_ids.append(text_ids)
+
+ max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
+ print(f"Max memory allocated: {max_memory:.3f} GB")
+ return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids
+
+
+def run(args):
+ dataset = load_dataset("Norod78/Yarn-art-style", split="train")
+ image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset}
+ all_prompts = list(image_prompts.values())
+ print(f"{len(all_prompts)=}")
+
+ pipeline = load_flux_dev_pipeline()
+ all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings(
+ pipeline, all_prompts, args.max_sequence_length
+ )
+
+ data = []
+ for i, (image_hash, _) in enumerate(image_prompts.items()):
+ data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i]))
+ print(f"{len(data)=}")
+
+ # Create a DataFrame
+ embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"]
+ df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols)
+ print(f"{len(df)=}")
+
+ # Convert embedding lists to arrays (for proper storage in parquet)
+ for col in embedding_cols:
+ df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
+
+ # Save the dataframe to a parquet file
+ df.to_parquet(args.output_path)
+ print(f"Data successfully serialized to {args.output_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=MAX_SEQ_LENGTH,
+ help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
+ )
+ parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
+ args = parser.parse_args()
+
+ run(args)
diff --git a/examples/research_projects/flux_lora_quantization/ds2.yaml b/examples/research_projects/flux_lora_quantization/ds2.yaml
new file mode 100644
index 000000000000..beed28fd90ab
--- /dev/null
+++ b/examples/research_projects/flux_lora_quantization/ds2.yaml
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ gradient_accumulation_steps: 1
+ gradient_clipping: 1.0
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
new file mode 100644
index 000000000000..ccaf3164a00c
--- /dev/null
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
@@ -0,0 +1,1200 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import copy
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ BitsAndBytesConfig,
+ FlowMatchEulerDiscreteScheduler,
+ FluxPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ free_memory,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ pass
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ base_model: str = None,
+ instance_prompt=None,
+ repo_folder=None,
+ quantization_config=None,
+):
+ widget_dict = []
+
+ model_description = f"""
+# Flux DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).
+
+Was LoRA for the text encoder enabled? False.
+
+Quantization config:
+
+```yaml
+{quantization_config}
+```
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## Usage
+
+TODO
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "flux",
+ "flux-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--quantized_model_path",
+ type=str,
+ default=None,
+ help="Path to the quantized model.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--data_df_path",
+ type=str,
+ default=None,
+ help=("Path to the parquet file serialized with compute_embeddings.py."),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=77,
+ help="Used for reading the embeddings. Needs to be the same as used during `compute_embeddings.py`.",
+ )
+
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="flux-dreambooth-lora-nf4",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+ help="the FLUX.1 dev variant is a guidance distilled model",
+ )
+
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ choices=["AdamW", "Prodigy", "AdEMAMix"],
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+ parser.add_argument(
+ "--use_8bit_ademamix",
+ action="store_true",
+ help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ def __init__(
+ self,
+ data_df_path,
+ dataset_name,
+ size=1024,
+ max_sequence_length=77,
+ center_crop=False,
+ ):
+ # Logistics
+ self.size = size
+ self.center_crop = center_crop
+ self.max_sequence_length = max_sequence_length
+
+ self.data_df_path = Path(data_df_path)
+ if not self.data_df_path.exists():
+ raise ValueError("`data_df_path` doesn't exists.")
+
+ # Load images.
+ dataset = load_dataset(dataset_name, split="train")
+ instance_images = [sample["image"] for sample in dataset]
+ image_hashes = [self.generate_image_hash(image) for image in instance_images]
+ self.instance_images = instance_images
+ self.image_hashes = image_hashes
+
+ # Image transformations
+ self.pixel_values = self.apply_image_transformations(
+ instance_images=instance_images, size=size, center_crop=center_crop
+ )
+
+ # Map hashes to embeddings.
+ self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path)
+
+ self.num_instance_images = len(instance_images)
+ self._length = self.num_instance_images
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ image_hash = self.image_hashes[index % self.num_instance_images]
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.data_dict[image_hash]
+ example["instance_images"] = instance_image
+ example["prompt_embeds"] = prompt_embeds
+ example["pooled_prompt_embeds"] = pooled_prompt_embeds
+ example["text_ids"] = text_ids
+ return example
+
+ def apply_image_transformations(self, instance_images, size, center_crop):
+ pixel_values = []
+
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ pixel_values.append(image)
+
+ return pixel_values
+
+ def convert_to_torch_tensor(self, embeddings: list):
+ prompt_embeds = embeddings[0]
+ pooled_prompt_embeds = embeddings[1]
+ text_ids = embeddings[2]
+ prompt_embeds = np.array(prompt_embeds).reshape(self.max_sequence_length, 4096)
+ pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(768)
+ text_ids = np.array(text_ids).reshape(77, 3)
+ return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds), torch.from_numpy(text_ids)
+
+ def map_image_hash_embedding(self, data_df_path):
+ hashes_df = pd.read_parquet(data_df_path)
+ data_dict = {}
+ for i, row in hashes_df.iterrows():
+ embeddings = [row["prompt_embeds"], row["pooled_prompt_embeds"], row["text_ids"]]
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.convert_to_torch_tensor(embeddings=embeddings)
+ data_dict.update({row["image_hash"]: (prompt_embeds, pooled_prompt_embeds, text_ids)})
+ return data_dict
+
+ def generate_image_hash(self, image):
+ return insecure_hashlib.sha256(image.tobytes()).hexdigest()
+
+
+def collate_fn(examples):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompt_embeds = [example["prompt_embeds"] for example in examples]
+ pooled_prompt_embeds = [example["pooled_prompt_embeds"] for example in examples]
+ text_ids = [example["text_ids"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ prompt_embeds = torch.stack(prompt_embeds)
+ pooled_prompt_embeds = torch.stack(pooled_prompt_embeds)
+ text_ids = torch.stack(text_ids)[0] # just 2D tensor
+
+ batch = {
+ "pixel_values": pixel_values,
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ "text_ids": text_ids,
+ }
+ return batch
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ bnb_4bit_compute_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ bnb_4bit_compute_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ bnb_4bit_compute_dtype = torch.bfloat16
+ if args.quantized_model_path is not None:
+ transformer = FluxTransformer2DModel.from_pretrained(
+ args.quantized_model_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=bnb_4bit_compute_dtype,
+ )
+ else:
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
+ )
+ transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ quantization_config=nf4_config,
+ torch_dtype=bnb_4bit_compute_dtype,
+ )
+ transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ vae.to(accelerator.device, dtype=weight_dtype)
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ # now we will add new LoRA weights to the attention layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ model = unwrap_model(model)
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ FluxPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ text_encoder_lora_layers=None,
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+ else:
+ if args.quantized_model_path is not None:
+ transformer_ = FluxTransformer2DModel.from_pretrained(
+ args.quantized_model_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=bnb_4bit_compute_dtype,
+ )
+ else:
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
+ )
+ transformer_ = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ quantization_config=nf4_config,
+ torch_dtype=bnb_4bit_compute_dtype,
+ )
+ transformer_ = prepare_model_for_kbit_training(transformer_, use_gradient_checkpointing=False)
+ transformer_.add_adapter(transformer_lora_config)
+
+ lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
+ logger.warning(
+ f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ elif args.optimizer.lower() == "ademamix":
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
+ )
+ if args.use_8bit_ademamix:
+ optimizer_class = bnb.optim.AdEMAMix8bit
+ else:
+ optimizer_class = bnb.optim.AdEMAMix
+
+ optimizer = optimizer_class(params_to_optimize)
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ data_df_path=args.data_df_path,
+ dataset_name="Norod78/Yarn-art-style",
+ size=args.resolution,
+ max_sequence_length=args.max_sequence_length,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ vae_config_shift_factor = vae.config.shift_factor
+ vae_config_scaling_factor = vae.config.scaling_factor
+ vae_config_block_out_channels = vae.config.block_out_channels
+ if args.cache_latents:
+ latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=weight_dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+ del vae
+ free_memory()
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-flux-dev-lora-nf4"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the mos recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ with accelerator.accumulate(models_to_accumulate):
+ # Convert images to latent space
+ if args.cache_latents:
+ model_input = latents_cache[step].sample()
+ else:
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
+
+ latent_image_ids = FluxPipeline._prepare_latent_image_ids(
+ model_input.shape[0],
+ model_input.shape[2] // 2,
+ model_input.shape[3] // 2,
+ accelerator.device,
+ weight_dtype,
+ )
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ packed_noisy_model_input = FluxPipeline._pack_latents(
+ noisy_model_input,
+ batch_size=model_input.shape[0],
+ num_channels_latents=model_input.shape[1],
+ height=model_input.shape[2],
+ width=model_input.shape[3],
+ )
+
+ # handle guidance
+ if unwrap_model(transformer).config.guidance_embeds:
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
+ guidance = guidance.expand(model_input.shape[0])
+ else:
+ guidance = None
+
+ # Predict the noise
+ prompt_embeds = batch["prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype)
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype)
+ text_ids = batch["text_ids"].to(device=accelerator.device, dtype=weight_dtype)
+ model_pred = transformer(
+ hidden_states=packed_noisy_model_input,
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
+ timestep=timesteps / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ return_dict=False,
+ )[0]
+ model_pred = FluxPipeline._unpack_latents(
+ model_pred,
+ height=model_input.shape[2] * vae_scale_factor,
+ width=model_input.shape[3] * vae_scale_factor,
+ vae_scale_factor=vae_scale_factor,
+ )
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
+ target = noise - model_input
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ params_to_clip = transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ FluxPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ text_encoder_lora_layers=None,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=None,
+ repo_folder=args.output_dir,
+ quantization_config=transformer.config["quantization_config"],
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/research_projects/instructpix2pix_lora/README.md b/examples/research_projects/instructpix2pix_lora/README.md
index cfcd98926c07..25f7931b47d4 100644
--- a/examples/research_projects/instructpix2pix_lora/README.md
+++ b/examples/research_projects/instructpix2pix_lora/README.md
@@ -2,6 +2,34 @@
This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). This script provides further support add LoRA layers for unet model.
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd in the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
+
+
## Training script example
```bash
@@ -9,7 +37,7 @@ export MODEL_ID="timbrooks/instruct-pix2pix"
export DATASET_ID="instruction-tuning-sd/cartoonization"
export OUTPUT_DIR="instructPix2Pix-cartoonization"
-accelerate launch finetune_instruct_pix2pix.py \
+accelerate launch train_instruct_pix2pix_lora.py \
--pretrained_model_name_or_path=$MODEL_ID \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
@@ -24,7 +52,10 @@ accelerate launch finetune_instruct_pix2pix.py \
--rank=4 \
--output_dir=$OUTPUT_DIR \
--report_to=wandb \
- --push_to_hub
+ --push_to_hub \
+ --original_image_column="original_image" \
+ --edited_image_column="cartoonized_image" \
+ --edit_prompt_column="edit_prompt"
```
## Inference
diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
index 997d448fa281..070cdad15564 100644
--- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
+++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
+"""
+ Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
+ Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
+"""
import argparse
import logging
@@ -30,6 +33,7 @@
import PIL
import requests
import torch
+import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
@@ -39,21 +43,28 @@
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
-from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.training_utils import EMAModel, cast_training_params
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -63,6 +74,92 @@
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ repo_folder: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
+{img_str}
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion",
+ "stable-diffusion-diffusers",
+ "text-to-image",
+ "instruct-pix2pix",
+ "diffusers",
+ "diffusers-training",
+ "lora",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ original_image = download_image(args.val_image_url)
+ edited_images = []
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ for _ in range(args.num_validation_images):
+ edited_images.append(
+ pipeline(
+ args.validation_prompt,
+ image=original_image,
+ num_inference_steps=20,
+ image_guidance_scale=1.5,
+ guidance_scale=7,
+ generator=generator,
+ ).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+ for edited_image in edited_images:
+ wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
+ tracker.log({"validation": wandb_table})
+
+ return edited_images
+
+
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
parser.add_argument(
@@ -417,11 +514,6 @@ def main():
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
- if args.report_to == "wandb":
- if not is_wandb_available():
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
- import wandb
-
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -467,49 +559,58 @@ def main():
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
+ # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
+ # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
+ # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
+ # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
+ # initialized to zero.
+ logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+ in_channels = 8
+ out_channels = unet.conv_in.out_channels
+ unet.register_to_config(in_channels=in_channels)
+
+ with torch.no_grad():
+ new_conv_in = nn.Conv2d(
+ in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+ )
+ new_conv_in.weight.zero_()
+ new_conv_in.weight[:, :in_channels, :, :].copy_(unet.conv_in.weight)
+ unet.conv_in = new_conv_in
+
# Freeze vae, text_encoder and unet
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
# referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py
- unet_lora_parameters = []
- for attn_processor_name, attn_processor in unet.attn_processors.items():
- # Parse the attention module.
- attn_module = unet
- for n in attn_processor_name.split(".")[:-1]:
- attn_module = getattr(attn_module, n)
-
- # Set the `lora_layer` attribute of the attention-related matrices.
- attn_module.to_q.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
- )
- )
- attn_module.to_k.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
- )
- )
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
- attn_module.to_v.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
- )
- )
- attn_module.to_out[0].set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_out[0].in_features,
- out_features=attn_module.to_out[0].out_features,
- rank=args.rank,
- )
- )
+ # Freeze the unet parameters before adding adapters
+ unet.requires_grad_(False)
- # Accumulate the LoRA params to optimize.
- unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
- unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
- unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
- unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Add adapter and make sure the trainable params are in float32.
+ unet.add_adapter(unet_lora_config)
+ if args.mixed_precision == "fp16":
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(unet, dtype=torch.float32)
# Create EMA for the unet.
if args.use_ema:
@@ -528,6 +629,13 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
+ trainable_params = filter(lambda p: p.requires_grad, unet.parameters())
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -540,7 +648,8 @@ def save_model_hook(models, weights, output_dir):
model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again
- weights.pop()
+ if weights:
+ weights.pop()
def load_model_hook(models, input_dir):
if args.use_ema:
@@ -589,9 +698,9 @@ def load_model_hook(models, input_dir):
else:
optimizer_cls = torch.optim.AdamW
- # train on only unet_lora_parameters
+ # train on only lora_layers
optimizer = optimizer_cls(
- unet_lora_parameters,
+ trainable_params,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -730,22 +839,27 @@ def collate_fn(examples):
)
# Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
)
# Prepare everything with our `accelerator`.
- unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
)
if args.use_ema:
@@ -765,8 +879,14 @@ def collate_fn(examples):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
+ if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -885,7 +1005,7 @@ def collate_fn(examples):
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
- model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
+ model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
@@ -895,7 +1015,7 @@ def collate_fn(examples):
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
- accelerator.clip_grad_norm_(unet_lora_parameters, args.max_grad_norm)
+ accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -903,7 +1023,7 @@ def collate_fn(examples):
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
- ema_unet.step(unet_lora_parameters)
+ ema_unet.step(trainable_params)
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -933,6 +1053,16 @@ def collate_fn(examples):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
+ unwrapped_unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(unwrapped_unet)
+ )
+
+ StableDiffusionInstructPix2PixPipeline.save_lora_weights(
+ save_directory=save_path,
+ unet_lora_layers=unet_lora_state_dict,
+ safe_serialization=True,
+ )
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -959,45 +1089,22 @@ def collate_fn(examples):
# The models need unwrapping because for compatibility in distributed training mode.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
- unet=accelerator.unwrap_model(unet),
- text_encoder=accelerator.unwrap_model(text_encoder),
- vae=accelerator.unwrap_model(vae),
+ unet=unwrap_model(unet),
+ text_encoder=unwrap_model(text_encoder),
+ vae=unwrap_model(vae),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
- pipeline = pipeline.to(accelerator.device)
- pipeline.set_progress_bar_config(disable=True)
# run inference
- original_image = download_image(args.val_image_url)
- edited_images = []
- if torch.backends.mps.is_available():
- autocast_ctx = nullcontext()
- else:
- autocast_ctx = torch.autocast(accelerator.device.type)
-
- with autocast_ctx:
- for _ in range(args.num_validation_images):
- edited_images.append(
- pipeline(
- args.validation_prompt,
- image=original_image,
- num_inference_steps=20,
- image_guidance_scale=1.5,
- guidance_scale=7,
- generator=generator,
- ).images[0]
- )
+ log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+ )
- for tracker in accelerator.trackers:
- if tracker.name == "wandb":
- wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
- for edited_image in edited_images:
- wandb_table.add_data(
- wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
- )
- tracker.log({"validation": wandb_table})
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
@@ -1008,22 +1115,47 @@ def collate_fn(examples):
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
- unet = accelerator.unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())
+ # store only LORA layers
+ unet = unet.to(torch.float32)
+
+ unwrapped_unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
+ StableDiffusionInstructPix2PixPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ safe_serialization=True,
+ )
+
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
- text_encoder=accelerator.unwrap_model(text_encoder),
- vae=accelerator.unwrap_model(vae),
- unet=unet,
+ text_encoder=unwrap_model(text_encoder),
+ vae=unwrap_model(vae),
+ unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
)
- # store only LORA layers
- unet.save_attn_procs(args.output_dir)
+ pipeline.load_lora_weights(args.output_dir)
+
+ images = None
+ if (args.val_image_url is not None) and (args.validation_prompt is not None):
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+ )
if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
@@ -1031,31 +1163,6 @@ def collate_fn(examples):
ignore_patterns=["step_*", "epoch_*"],
)
- if args.validation_prompt is not None:
- edited_images = []
- pipeline = pipeline.to(accelerator.device)
- with torch.autocast(str(accelerator.device).replace(":0", "")):
- for _ in range(args.num_validation_images):
- edited_images.append(
- pipeline(
- args.validation_prompt,
- image=original_image,
- num_inference_steps=20,
- image_guidance_scale=1.5,
- guidance_scale=7,
- generator=generator,
- ).images[0]
- )
-
- for tracker in accelerator.trackers:
- if tracker.name == "wandb":
- wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
- for edited_image in edited_images:
- wandb_table.add_data(
- wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
- )
- tracker.log({"test": wandb_table})
-
accelerator.end_training()
diff --git a/examples/research_projects/ip_adapter/README.md b/examples/research_projects/ip_adapter/README.md
new file mode 100644
index 000000000000..04a6c86e5305
--- /dev/null
+++ b/examples/research_projects/ip_adapter/README.md
@@ -0,0 +1,226 @@
+# IP Adapter Training Example
+
+[IP Adapter](https://arxiv.org/abs/2308.06721) is a novel approach designed to enhance text-to-image models such as Stable Diffusion by enabling them to generate images based on image prompts rather than text prompts alone. Unlike traditional methods that rely solely on complex text prompts, IP Adapter introduces the concept of using image prompts, leveraging the idea that "an image is worth a thousand words." By decoupling cross-attention layers for text and image features, IP Adapter effectively integrates image prompts into the generation process without the need for extensive fine-tuning or large computing resources.
+
+## Training locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd in the example folder and run
+
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell e.g. a notebook
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+Certainly! Below is the documentation in pure Markdown format:
+
+### Accelerate Launch Command Documentation
+
+#### Description:
+The Accelerate launch command is used to train a model using multiple GPUs and mixed precision training. It launches the training script `tutorial_train_ip-adapter.py` with specified parameters and configurations.
+
+#### Usage Example:
+
+```
+accelerate launch --mixed_precision "fp16" \
+tutorial_train_ip-adapter.py \
+--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
+--image_encoder_path="{image_encoder_path}" \
+--data_json_file="{data.json}" \
+--data_root_path="{image_path}" \
+--mixed_precision="fp16" \
+--resolution=512 \
+--train_batch_size=8 \
+--dataloader_num_workers=4 \
+--learning_rate=1e-04 \
+--weight_decay=0.01 \
+--output_dir="{output_dir}" \
+--save_steps=10000
+```
+
+### Multi-GPU Script:
+```
+accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
+ tutorial_train_ip-adapter.py \
+ --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
+ --image_encoder_path="{image_encoder_path}" \
+ --data_json_file="{data.json}" \
+ --data_root_path="{image_path}" \
+ --mixed_precision="fp16" \
+ --resolution=512 \
+ --train_batch_size=8 \
+ --dataloader_num_workers=4 \
+ --learning_rate=1e-04 \
+ --weight_decay=0.01 \
+ --output_dir="{output_dir}" \
+ --save_steps=10000
+```
+
+#### Parameters:
+- `--num_processes`: Number of processes to launch for distributed training (in this example, 8 processes).
+- `--multi_gpu`: Flag indicating the usage of multiple GPUs for training.
+- `--mixed_precision "fp16"`: Enables mixed precision training with 16-bit floating-point precision.
+- `tutorial_train_ip-adapter.py`: Name of the training script to be executed.
+- `--pretrained_model_name_or_path`: Path or identifier for a pretrained model.
+- `--image_encoder_path`: Path to the CLIP image encoder.
+- `--data_json_file`: Path to the training data in JSON format.
+- `--data_root_path`: Root path where training images are located.
+- `--resolution`: Resolution of input images (512x512 in this example).
+- `--train_batch_size`: Batch size for training data (8 in this example).
+- `--dataloader_num_workers`: Number of subprocesses for data loading (4 in this example).
+- `--learning_rate`: Learning rate for training (1e-04 in this example).
+- `--weight_decay`: Weight decay for regularization (0.01 in this example).
+- `--output_dir`: Directory to save model checkpoints and predictions.
+- `--save_steps`: Frequency of saving checkpoints during training (10000 in this example).
+
+### Inference
+
+#### Description:
+The provided inference code is used to load a trained model checkpoint and extract the components related to image projection and IP (Image Processing) adapter. These components are then saved into a binary file for later use in inference.
+
+#### Usage Example:
+```python
+from safetensors.torch import load_file, save_file
+
+# Load the trained model checkpoint in safetensors format
+ckpt = "checkpoint-50000/pytorch_model.safetensors"
+sd = load_file(ckpt) # Using safetensors load function
+
+# Extract image projection and IP adapter components
+image_proj_sd = {}
+ip_sd = {}
+
+for k in sd:
+ if k.startswith("unet"):
+ pass # Skip unet-related keys
+ elif k.startswith("image_proj_model"):
+ image_proj_sd[k.replace("image_proj_model.", "")] = sd[k]
+ elif k.startswith("adapter_modules"):
+ ip_sd[k.replace("adapter_modules.", "")] = sd[k]
+
+# Save the components into separate safetensors files
+save_file(image_proj_sd, "image_proj.safetensors")
+save_file(ip_sd, "ip_adapter.safetensors")
+```
+
+### Sample Inference Script using the CLIP Model
+
+```python
+
+import torch
+from safetensors.torch import load_file
+from transformers import CLIPProcessor, CLIPModel # Using the Hugging Face CLIP model
+
+# Load model components from safetensors
+image_proj_ckpt = "image_proj.safetensors"
+ip_adapter_ckpt = "ip_adapter.safetensors"
+
+# Load the saved weights
+image_proj_sd = load_file(image_proj_ckpt)
+ip_adapter_sd = load_file(ip_adapter_ckpt)
+
+# Define the model Parameters
+class ImageProjectionModel(torch.nn.Module):
+ def __init__(self, input_dim=768, output_dim=512): # CLIP's default embedding size is 768
+ super().__init__()
+ self.model = torch.nn.Linear(input_dim, output_dim)
+
+ def forward(self, x):
+ return self.model(x)
+
+class IPAdapterModel(torch.nn.Module):
+ def __init__(self, input_dim=512, output_dim=10): # Example for 10 classes
+ super().__init__()
+ self.model = torch.nn.Linear(input_dim, output_dim)
+
+ def forward(self, x):
+ return self.model(x)
+
+# Initialize models
+image_proj_model = ImageProjectionModel()
+ip_adapter_model = IPAdapterModel()
+
+# Load weights into models
+image_proj_model.load_state_dict(image_proj_sd)
+ip_adapter_model.load_state_dict(ip_adapter_sd)
+
+# Set models to evaluation mode
+image_proj_model.eval()
+ip_adapter_model.eval()
+
+#Inference pipeline
+def inference(image_tensor):
+ """
+ Run inference using the loaded models.
+
+ Args:
+ image_tensor: Preprocessed image tensor from CLIPProcessor
+
+ Returns:
+ Final inference results
+ """
+ with torch.no_grad():
+ # Step 1: Project the image features
+ image_proj = image_proj_model(image_tensor)
+
+ # Step 2: Pass the projected features through the IP Adapter
+ result = ip_adapter_model(image_proj)
+
+ return result
+
+# Using CLIP for image preprocessing
+processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+
+#Image file path
+image_path = "path/to/image.jpg"
+
+# Preprocess the image
+inputs = processor(images=image_path, return_tensors="pt")
+image_features = clip_model.get_image_features(inputs["pixel_values"])
+
+# Normalize the image features as per CLIP's recommendations
+image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+
+# Run inference
+output = inference(image_features)
+print("Inference output:", output)
+```
+
+#### Parameters:
+- `ckpt`: Path to the trained model checkpoint file.
+- `map_location="cpu"`: Specifies that the model should be loaded onto the CPU.
+- `image_proj_sd`: Dictionary to store the components related to image projection.
+- `ip_sd`: Dictionary to store the components related to the IP adapter.
+- `"unet"`, `"image_proj_model"`, `"adapter_modules"`: Prefixes indicating components of the model.
\ No newline at end of file
diff --git a/examples/research_projects/ip_adapter/requirements.txt b/examples/research_projects/ip_adapter/requirements.txt
new file mode 100644
index 000000000000..749aa795015d
--- /dev/null
+++ b/examples/research_projects/ip_adapter/requirements.txt
@@ -0,0 +1,4 @@
+accelerate
+torchvision
+transformers>=4.25.1
+ip_adapter
diff --git a/examples/research_projects/ip_adapter/tutorial_train_faceid.py b/examples/research_projects/ip_adapter/tutorial_train_faceid.py
new file mode 100644
index 000000000000..3e337ec02f7f
--- /dev/null
+++ b/examples/research_projects/ip_adapter/tutorial_train_faceid.py
@@ -0,0 +1,415 @@
+import argparse
+import itertools
+import json
+import os
+import random
+import time
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+from accelerate.utils import ProjectConfiguration
+from ip_adapter.attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
+from ip_adapter.ip_adapter_faceid import MLPProjModel
+from PIL import Image
+from torchvision import transforms
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+
+
+# Dataset
+class MyDataset(torch.utils.data.Dataset):
+ def __init__(
+ self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
+ ):
+ super().__init__()
+
+ self.tokenizer = tokenizer
+ self.size = size
+ self.i_drop_rate = i_drop_rate
+ self.t_drop_rate = t_drop_rate
+ self.ti_drop_rate = ti_drop_rate
+ self.image_root_path = image_root_path
+
+ self.data = json.load(
+ open(json_file)
+ ) # list of dict: [{"image_file": "1.png", "id_embed_file": "faceid.bin"}]
+
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(self.size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __getitem__(self, idx):
+ item = self.data[idx]
+ text = item["text"]
+ image_file = item["image_file"]
+
+ # read image
+ raw_image = Image.open(os.path.join(self.image_root_path, image_file))
+ image = self.transform(raw_image.convert("RGB"))
+
+ face_id_embed = torch.load(item["id_embed_file"], map_location="cpu")
+ face_id_embed = torch.from_numpy(face_id_embed)
+
+ # drop
+ drop_image_embed = 0
+ rand_num = random.random()
+ if rand_num < self.i_drop_rate:
+ drop_image_embed = 1
+ elif rand_num < (self.i_drop_rate + self.t_drop_rate):
+ text = ""
+ elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
+ text = ""
+ drop_image_embed = 1
+ if drop_image_embed:
+ face_id_embed = torch.zeros_like(face_id_embed)
+ # get text and tokenize
+ text_input_ids = self.tokenizer(
+ text,
+ max_length=self.tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids
+
+ return {
+ "image": image,
+ "text_input_ids": text_input_ids,
+ "face_id_embed": face_id_embed,
+ "drop_image_embed": drop_image_embed,
+ }
+
+ def __len__(self):
+ return len(self.data)
+
+
+def collate_fn(data):
+ images = torch.stack([example["image"] for example in data])
+ text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
+ face_id_embed = torch.stack([example["face_id_embed"] for example in data])
+ drop_image_embeds = [example["drop_image_embed"] for example in data]
+
+ return {
+ "images": images,
+ "text_input_ids": text_input_ids,
+ "face_id_embed": face_id_embed,
+ "drop_image_embeds": drop_image_embeds,
+ }
+
+
+class IPAdapter(torch.nn.Module):
+ """IP-Adapter"""
+
+ def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
+ super().__init__()
+ self.unet = unet
+ self.image_proj_model = image_proj_model
+ self.adapter_modules = adapter_modules
+
+ if ckpt_path is not None:
+ self.load_from_checkpoint(ckpt_path)
+
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
+ ip_tokens = self.image_proj_model(image_embeds)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
+ # Predict the noise residual
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
+ return noise_pred
+
+ def load_from_checkpoint(self, ckpt_path: str):
+ # Calculate original checksums
+ orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+ orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+ state_dict = torch.load(ckpt_path, map_location="cpu")
+
+ # Load state dict for image_proj_model and adapter_modules
+ self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
+ self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
+
+ # Calculate new checksums
+ new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+ new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+ # Verify if the weights have changed
+ assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
+ assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
+
+ print(f"Successfully loaded weights from checkpoint {ckpt_path}")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_ip_adapter_path",
+ type=str,
+ default=None,
+ help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
+ )
+ parser.add_argument(
+ "--data_json_file",
+ type=str,
+ default=None,
+ required=True,
+ help="Training data",
+ )
+ parser.add_argument(
+ "--data_root_path",
+ type=str,
+ default="",
+ required=True,
+ help="Training data root path",
+ )
+ parser.add_argument(
+ "--image_encoder_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to CLIP image encoder",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-ip_adapter",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=("The resolution for input images"),
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Learning rate to use.",
+ )
+ parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=2000,
+ help=("Save a checkpoint of the training state every X updates"),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+ # image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ # image_encoder.requires_grad_(False)
+
+ # ip-adapter
+ image_proj_model = MLPProjModel(
+ cross_attention_dim=unet.config.cross_attention_dim,
+ id_embeddings_dim=512,
+ num_tokens=4,
+ )
+ # init adapter modules
+ lora_rank = 128
+ attn_procs = {}
+ unet_sd = unet.state_dict()
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = LoRAAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
+ )
+ else:
+ layer_name = name.split(".processor")[0]
+ weights = {
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
+ }
+ attn_procs[name] = LoRAIPAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
+ )
+ attn_procs[name].load_state_dict(weights, strict=False)
+ unet.set_attn_processor(attn_procs)
+ adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
+
+ ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ # unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ # image_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # optimizer
+ params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
+ optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
+
+ # dataloader
+ train_dataset = MyDataset(
+ args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Prepare everything with our `accelerator`.
+ ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
+
+ global_step = 0
+ for epoch in range(0, args.num_train_epochs):
+ begin = time.perf_counter()
+ for step, batch in enumerate(train_dataloader):
+ load_data_time = time.perf_counter() - begin
+ with accelerator.accumulate(ip_adapter):
+ # Convert images to latent space
+ with torch.no_grad():
+ latents = vae.encode(
+ batch["images"].to(accelerator.device, dtype=weight_dtype)
+ ).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ image_embeds = batch["face_id_embed"].to(accelerator.device, dtype=weight_dtype)
+
+ with torch.no_grad():
+ encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
+
+ noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
+
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
+
+ # Backpropagate
+ accelerator.backward(loss)
+ optimizer.step()
+ optimizer.zero_grad()
+
+ if accelerator.is_main_process:
+ print(
+ "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
+ epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
+ )
+ )
+
+ global_step += 1
+
+ if global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+
+ begin = time.perf_counter()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py b/examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py
new file mode 100644
index 000000000000..9a3513f4c549
--- /dev/null
+++ b/examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py
@@ -0,0 +1,422 @@
+import argparse
+import itertools
+import json
+import os
+import random
+import time
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+from accelerate.utils import ProjectConfiguration
+from ip_adapter.ip_adapter import ImageProjModel
+from ip_adapter.utils import is_torch2_available
+from PIL import Image
+from torchvision import transforms
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+
+
+if is_torch2_available():
+ from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
+ from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
+else:
+ from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
+
+
+# Dataset
+class MyDataset(torch.utils.data.Dataset):
+ def __init__(
+ self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
+ ):
+ super().__init__()
+
+ self.tokenizer = tokenizer
+ self.size = size
+ self.i_drop_rate = i_drop_rate
+ self.t_drop_rate = t_drop_rate
+ self.ti_drop_rate = ti_drop_rate
+ self.image_root_path = image_root_path
+
+ self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
+
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(self.size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ self.clip_image_processor = CLIPImageProcessor()
+
+ def __getitem__(self, idx):
+ item = self.data[idx]
+ text = item["text"]
+ image_file = item["image_file"]
+
+ # read image
+ raw_image = Image.open(os.path.join(self.image_root_path, image_file))
+ image = self.transform(raw_image.convert("RGB"))
+ clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
+
+ # drop
+ drop_image_embed = 0
+ rand_num = random.random()
+ if rand_num < self.i_drop_rate:
+ drop_image_embed = 1
+ elif rand_num < (self.i_drop_rate + self.t_drop_rate):
+ text = ""
+ elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
+ text = ""
+ drop_image_embed = 1
+ # get text and tokenize
+ text_input_ids = self.tokenizer(
+ text,
+ max_length=self.tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids
+
+ return {
+ "image": image,
+ "text_input_ids": text_input_ids,
+ "clip_image": clip_image,
+ "drop_image_embed": drop_image_embed,
+ }
+
+ def __len__(self):
+ return len(self.data)
+
+
+def collate_fn(data):
+ images = torch.stack([example["image"] for example in data])
+ text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
+ clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
+ drop_image_embeds = [example["drop_image_embed"] for example in data]
+
+ return {
+ "images": images,
+ "text_input_ids": text_input_ids,
+ "clip_images": clip_images,
+ "drop_image_embeds": drop_image_embeds,
+ }
+
+
+class IPAdapter(torch.nn.Module):
+ """IP-Adapter"""
+
+ def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
+ super().__init__()
+ self.unet = unet
+ self.image_proj_model = image_proj_model
+ self.adapter_modules = adapter_modules
+
+ if ckpt_path is not None:
+ self.load_from_checkpoint(ckpt_path)
+
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
+ ip_tokens = self.image_proj_model(image_embeds)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
+ # Predict the noise residual
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
+ return noise_pred
+
+ def load_from_checkpoint(self, ckpt_path: str):
+ # Calculate original checksums
+ orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+ orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+ state_dict = torch.load(ckpt_path, map_location="cpu")
+
+ # Load state dict for image_proj_model and adapter_modules
+ self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
+ self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
+
+ # Calculate new checksums
+ new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+ new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+ # Verify if the weights have changed
+ assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
+ assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
+
+ print(f"Successfully loaded weights from checkpoint {ckpt_path}")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_ip_adapter_path",
+ type=str,
+ default=None,
+ help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
+ )
+ parser.add_argument(
+ "--data_json_file",
+ type=str,
+ default=None,
+ required=True,
+ help="Training data",
+ )
+ parser.add_argument(
+ "--data_root_path",
+ type=str,
+ default="",
+ required=True,
+ help="Training data root path",
+ )
+ parser.add_argument(
+ "--image_encoder_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to CLIP image encoder",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-ip_adapter",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=("The resolution for input images"),
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Learning rate to use.",
+ )
+ parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=2000,
+ help=("Save a checkpoint of the training state every X updates"),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+
+ # ip-adapter
+ image_proj_model = ImageProjModel(
+ cross_attention_dim=unet.config.cross_attention_dim,
+ clip_embeddings_dim=image_encoder.config.projection_dim,
+ clip_extra_context_tokens=4,
+ )
+ # init adapter modules
+ attn_procs = {}
+ unet_sd = unet.state_dict()
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = AttnProcessor()
+ else:
+ layer_name = name.split(".processor")[0]
+ weights = {
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
+ }
+ attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+ attn_procs[name].load_state_dict(weights)
+ unet.set_attn_processor(attn_procs)
+ adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
+
+ ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ # unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # optimizer
+ params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
+ optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
+
+ # dataloader
+ train_dataset = MyDataset(
+ args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Prepare everything with our `accelerator`.
+ ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
+
+ global_step = 0
+ for epoch in range(0, args.num_train_epochs):
+ begin = time.perf_counter()
+ for step, batch in enumerate(train_dataloader):
+ load_data_time = time.perf_counter() - begin
+ with accelerator.accumulate(ip_adapter):
+ # Convert images to latent space
+ with torch.no_grad():
+ latents = vae.encode(
+ batch["images"].to(accelerator.device, dtype=weight_dtype)
+ ).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ with torch.no_grad():
+ image_embeds = image_encoder(
+ batch["clip_images"].to(accelerator.device, dtype=weight_dtype)
+ ).image_embeds
+ image_embeds_ = []
+ for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
+ if drop_image_embed == 1:
+ image_embeds_.append(torch.zeros_like(image_embed))
+ else:
+ image_embeds_.append(image_embed)
+ image_embeds = torch.stack(image_embeds_)
+
+ with torch.no_grad():
+ encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
+
+ noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
+
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
+
+ # Backpropagate
+ accelerator.backward(loss)
+ optimizer.step()
+ optimizer.zero_grad()
+
+ if accelerator.is_main_process:
+ print(
+ "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
+ epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
+ )
+ )
+
+ global_step += 1
+
+ if global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+
+ begin = time.perf_counter()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/ip_adapter/tutorial_train_plus.py b/examples/research_projects/ip_adapter/tutorial_train_plus.py
new file mode 100644
index 000000000000..e777ea1f0047
--- /dev/null
+++ b/examples/research_projects/ip_adapter/tutorial_train_plus.py
@@ -0,0 +1,445 @@
+import argparse
+import itertools
+import json
+import os
+import random
+import time
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+from accelerate.utils import ProjectConfiguration
+from ip_adapter.resampler import Resampler
+from ip_adapter.utils import is_torch2_available
+from PIL import Image
+from torchvision import transforms
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+
+
+if is_torch2_available():
+ from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
+ from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
+else:
+ from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
+
+
+# Dataset
+class MyDataset(torch.utils.data.Dataset):
+ def __init__(
+ self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
+ ):
+ super().__init__()
+
+ self.tokenizer = tokenizer
+ self.size = size
+ self.i_drop_rate = i_drop_rate
+ self.t_drop_rate = t_drop_rate
+ self.ti_drop_rate = ti_drop_rate
+ self.image_root_path = image_root_path
+
+ self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
+
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(self.size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ self.clip_image_processor = CLIPImageProcessor()
+
+ def __getitem__(self, idx):
+ item = self.data[idx]
+ text = item["text"]
+ image_file = item["image_file"]
+
+ # read image
+ raw_image = Image.open(os.path.join(self.image_root_path, image_file))
+ image = self.transform(raw_image.convert("RGB"))
+ clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
+
+ # drop
+ drop_image_embed = 0
+ rand_num = random.random()
+ if rand_num < self.i_drop_rate:
+ drop_image_embed = 1
+ elif rand_num < (self.i_drop_rate + self.t_drop_rate):
+ text = ""
+ elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
+ text = ""
+ drop_image_embed = 1
+ # get text and tokenize
+ text_input_ids = self.tokenizer(
+ text,
+ max_length=self.tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids
+
+ return {
+ "image": image,
+ "text_input_ids": text_input_ids,
+ "clip_image": clip_image,
+ "drop_image_embed": drop_image_embed,
+ }
+
+ def __len__(self):
+ return len(self.data)
+
+
+def collate_fn(data):
+ images = torch.stack([example["image"] for example in data])
+ text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
+ clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
+ drop_image_embeds = [example["drop_image_embed"] for example in data]
+
+ return {
+ "images": images,
+ "text_input_ids": text_input_ids,
+ "clip_images": clip_images,
+ "drop_image_embeds": drop_image_embeds,
+ }
+
+
+class IPAdapter(torch.nn.Module):
+ """IP-Adapter"""
+
+ def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
+ super().__init__()
+ self.unet = unet
+ self.image_proj_model = image_proj_model
+ self.adapter_modules = adapter_modules
+
+ if ckpt_path is not None:
+ self.load_from_checkpoint(ckpt_path)
+
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
+ ip_tokens = self.image_proj_model(image_embeds)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
+ # Predict the noise residual
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
+ return noise_pred
+
+ def load_from_checkpoint(self, ckpt_path: str):
+ # Calculate original checksums
+ orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+ orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+ state_dict = torch.load(ckpt_path, map_location="cpu")
+
+ # Check if 'latents' exists in both the saved state_dict and the current model's state_dict
+ strict_load_image_proj_model = True
+ if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict():
+ # Check if the shapes are mismatched
+ if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape:
+ print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
+ print("Removing 'latents' from checkpoint and loading the rest of the weights.")
+ del state_dict["image_proj"]["latents"]
+ strict_load_image_proj_model = False
+
+ # Load state dict for image_proj_model and adapter_modules
+ self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
+ self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
+
+ # Calculate new checksums
+ new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+ new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+ # Verify if the weights have changed
+ assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
+ assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
+
+ print(f"Successfully loaded weights from checkpoint {ckpt_path}")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_ip_adapter_path",
+ type=str,
+ default=None,
+ help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
+ )
+ parser.add_argument(
+ "--num_tokens",
+ type=int,
+ default=16,
+ help="Number of tokens to query from the CLIP image encoding.",
+ )
+ parser.add_argument(
+ "--data_json_file",
+ type=str,
+ default=None,
+ required=True,
+ help="Training data",
+ )
+ parser.add_argument(
+ "--data_root_path",
+ type=str,
+ default="",
+ required=True,
+ help="Training data root path",
+ )
+ parser.add_argument(
+ "--image_encoder_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to CLIP image encoder",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-ip_adapter",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=("The resolution for input images"),
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Learning rate to use.",
+ )
+ parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=2000,
+ help=("Save a checkpoint of the training state every X updates"),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+
+ # ip-adapter-plus
+ image_proj_model = Resampler(
+ dim=unet.config.cross_attention_dim,
+ depth=4,
+ dim_head=64,
+ heads=12,
+ num_queries=args.num_tokens,
+ embedding_dim=image_encoder.config.hidden_size,
+ output_dim=unet.config.cross_attention_dim,
+ ff_mult=4,
+ )
+ # init adapter modules
+ attn_procs = {}
+ unet_sd = unet.state_dict()
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = AttnProcessor()
+ else:
+ layer_name = name.split(".processor")[0]
+ weights = {
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
+ }
+ attn_procs[name] = IPAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens
+ )
+ attn_procs[name].load_state_dict(weights)
+ unet.set_attn_processor(attn_procs)
+ adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
+
+ ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ # unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # optimizer
+ params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
+ optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
+
+ # dataloader
+ train_dataset = MyDataset(
+ args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Prepare everything with our `accelerator`.
+ ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
+
+ global_step = 0
+ for epoch in range(0, args.num_train_epochs):
+ begin = time.perf_counter()
+ for step, batch in enumerate(train_dataloader):
+ load_data_time = time.perf_counter() - begin
+ with accelerator.accumulate(ip_adapter):
+ # Convert images to latent space
+ with torch.no_grad():
+ latents = vae.encode(
+ batch["images"].to(accelerator.device, dtype=weight_dtype)
+ ).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ clip_images = []
+ for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
+ if drop_image_embed == 1:
+ clip_images.append(torch.zeros_like(clip_image))
+ else:
+ clip_images.append(clip_image)
+ clip_images = torch.stack(clip_images, dim=0)
+ with torch.no_grad():
+ image_embeds = image_encoder(
+ clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True
+ ).hidden_states[-2]
+
+ with torch.no_grad():
+ encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
+
+ noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
+
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
+
+ # Backpropagate
+ accelerator.backward(loss)
+ optimizer.step()
+ optimizer.zero_grad()
+
+ if accelerator.is_main_process:
+ print(
+ "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
+ epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
+ )
+ )
+
+ global_step += 1
+
+ if global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+
+ begin = time.perf_counter()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/ip_adapter/tutorial_train_sdxl.py b/examples/research_projects/ip_adapter/tutorial_train_sdxl.py
new file mode 100644
index 000000000000..cd7dffe13a80
--- /dev/null
+++ b/examples/research_projects/ip_adapter/tutorial_train_sdxl.py
@@ -0,0 +1,520 @@
+import argparse
+import itertools
+import json
+import os
+import random
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+from accelerate.utils import ProjectConfiguration
+from ip_adapter.ip_adapter import ImageProjModel
+from ip_adapter.utils import is_torch2_available
+from PIL import Image
+from torchvision import transforms
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+
+
+if is_torch2_available():
+ from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
+ from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
+else:
+ from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
+
+
+# Dataset
+class MyDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ json_file,
+ tokenizer,
+ tokenizer_2,
+ size=1024,
+ center_crop=True,
+ t_drop_rate=0.05,
+ i_drop_rate=0.05,
+ ti_drop_rate=0.05,
+ image_root_path="",
+ ):
+ super().__init__()
+
+ self.tokenizer = tokenizer
+ self.tokenizer_2 = tokenizer_2
+ self.size = size
+ self.center_crop = center_crop
+ self.i_drop_rate = i_drop_rate
+ self.t_drop_rate = t_drop_rate
+ self.ti_drop_rate = ti_drop_rate
+ self.image_root_path = image_root_path
+
+ self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
+
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ self.clip_image_processor = CLIPImageProcessor()
+
+ def __getitem__(self, idx):
+ item = self.data[idx]
+ text = item["text"]
+ image_file = item["image_file"]
+
+ # read image
+ raw_image = Image.open(os.path.join(self.image_root_path, image_file))
+
+ # original size
+ original_width, original_height = raw_image.size
+ original_size = torch.tensor([original_height, original_width])
+
+ image_tensor = self.transform(raw_image.convert("RGB"))
+ # random crop
+ delta_h = image_tensor.shape[1] - self.size
+ delta_w = image_tensor.shape[2] - self.size
+ assert not all([delta_h, delta_w])
+
+ if self.center_crop:
+ top = delta_h // 2
+ left = delta_w // 2
+ else:
+ top = np.random.randint(0, delta_h + 1)
+ left = np.random.randint(0, delta_w + 1)
+ image = transforms.functional.crop(image_tensor, top=top, left=left, height=self.size, width=self.size)
+ crop_coords_top_left = torch.tensor([top, left])
+
+ clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
+
+ # drop
+ drop_image_embed = 0
+ rand_num = random.random()
+ if rand_num < self.i_drop_rate:
+ drop_image_embed = 1
+ elif rand_num < (self.i_drop_rate + self.t_drop_rate):
+ text = ""
+ elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
+ text = ""
+ drop_image_embed = 1
+
+ # get text and tokenize
+ text_input_ids = self.tokenizer(
+ text,
+ max_length=self.tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids
+
+ text_input_ids_2 = self.tokenizer_2(
+ text,
+ max_length=self.tokenizer_2.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids
+
+ return {
+ "image": image,
+ "text_input_ids": text_input_ids,
+ "text_input_ids_2": text_input_ids_2,
+ "clip_image": clip_image,
+ "drop_image_embed": drop_image_embed,
+ "original_size": original_size,
+ "crop_coords_top_left": crop_coords_top_left,
+ "target_size": torch.tensor([self.size, self.size]),
+ }
+
+ def __len__(self):
+ return len(self.data)
+
+
+def collate_fn(data):
+ images = torch.stack([example["image"] for example in data])
+ text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
+ text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0)
+ clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
+ drop_image_embeds = [example["drop_image_embed"] for example in data]
+ original_size = torch.stack([example["original_size"] for example in data])
+ crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data])
+ target_size = torch.stack([example["target_size"] for example in data])
+
+ return {
+ "images": images,
+ "text_input_ids": text_input_ids,
+ "text_input_ids_2": text_input_ids_2,
+ "clip_images": clip_images,
+ "drop_image_embeds": drop_image_embeds,
+ "original_size": original_size,
+ "crop_coords_top_left": crop_coords_top_left,
+ "target_size": target_size,
+ }
+
+
+class IPAdapter(torch.nn.Module):
+ """IP-Adapter"""
+
+ def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
+ super().__init__()
+ self.unet = unet
+ self.image_proj_model = image_proj_model
+ self.adapter_modules = adapter_modules
+
+ if ckpt_path is not None:
+ self.load_from_checkpoint(ckpt_path)
+
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
+ ip_tokens = self.image_proj_model(image_embeds)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
+ # Predict the noise residual
+ noise_pred = self.unet(
+ noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs
+ ).sample
+ return noise_pred
+
+ def load_from_checkpoint(self, ckpt_path: str):
+ # Calculate original checksums
+ orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+ orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+ state_dict = torch.load(ckpt_path, map_location="cpu")
+
+ # Load state dict for image_proj_model and adapter_modules
+ self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
+ self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
+
+ # Calculate new checksums
+ new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
+ new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
+
+ # Verify if the weights have changed
+ assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
+ assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
+
+ print(f"Successfully loaded weights from checkpoint {ckpt_path}")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_ip_adapter_path",
+ type=str,
+ default=None,
+ help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
+ )
+ parser.add_argument(
+ "--data_json_file",
+ type=str,
+ default=None,
+ required=True,
+ help="Training data",
+ )
+ parser.add_argument(
+ "--data_root_path",
+ type=str,
+ default="",
+ required=True,
+ help="Training data root path",
+ )
+ parser.add_argument(
+ "--image_encoder_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to CLIP image encoder",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-ip_adapter",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=("The resolution for input images"),
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Learning rate to use.",
+ )
+ parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--noise_offset", type=float, default=None, help="noise offset")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=2000,
+ help=("Save a checkpoint of the training state every X updates"),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
+ text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2"
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ text_encoder_2.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+
+ # ip-adapter
+ num_tokens = 4
+ image_proj_model = ImageProjModel(
+ cross_attention_dim=unet.config.cross_attention_dim,
+ clip_embeddings_dim=image_encoder.config.projection_dim,
+ clip_extra_context_tokens=num_tokens,
+ )
+ # init adapter modules
+ attn_procs = {}
+ unet_sd = unet.state_dict()
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = AttnProcessor()
+ else:
+ layer_name = name.split(".processor")[0]
+ weights = {
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
+ }
+ attn_procs[name] = IPAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens
+ )
+ attn_procs[name].load_state_dict(weights)
+ unet.set_attn_processor(attn_procs)
+ adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
+
+ ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ # unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device) # use fp32
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_2.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # optimizer
+ params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
+ optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
+
+ # dataloader
+ train_dataset = MyDataset(
+ args.data_json_file,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ size=args.resolution,
+ image_root_path=args.data_root_path,
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Prepare everything with our `accelerator`.
+ ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
+
+ global_step = 0
+ for epoch in range(0, args.num_train_epochs):
+ begin = time.perf_counter()
+ for step, batch in enumerate(train_dataloader):
+ load_data_time = time.perf_counter() - begin
+ with accelerator.accumulate(ip_adapter):
+ # Convert images to latent space
+ with torch.no_grad():
+ # vae of sdxl should use fp32
+ latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae.dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+ latents = latents.to(accelerator.device, dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(
+ accelerator.device, dtype=weight_dtype
+ )
+
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ with torch.no_grad():
+ image_embeds = image_encoder(
+ batch["clip_images"].to(accelerator.device, dtype=weight_dtype)
+ ).image_embeds
+ image_embeds_ = []
+ for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
+ if drop_image_embed == 1:
+ image_embeds_.append(torch.zeros_like(image_embed))
+ else:
+ image_embeds_.append(image_embed)
+ image_embeds = torch.stack(image_embeds_)
+
+ with torch.no_grad():
+ encoder_output = text_encoder(
+ batch["text_input_ids"].to(accelerator.device), output_hidden_states=True
+ )
+ text_embeds = encoder_output.hidden_states[-2]
+ encoder_output_2 = text_encoder_2(
+ batch["text_input_ids_2"].to(accelerator.device), output_hidden_states=True
+ )
+ pooled_text_embeds = encoder_output_2[0]
+ text_embeds_2 = encoder_output_2.hidden_states[-2]
+ text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1) # concat
+
+ # add cond
+ add_time_ids = [
+ batch["original_size"].to(accelerator.device),
+ batch["crop_coords_top_left"].to(accelerator.device),
+ batch["target_size"].to(accelerator.device),
+ ]
+ add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype)
+ unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}
+
+ noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds)
+
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
+
+ # Backpropagate
+ accelerator.backward(loss)
+ optimizer.step()
+ optimizer.zero_grad()
+
+ if accelerator.is_main_process:
+ print(
+ "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
+ epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
+ )
+ )
+
+ global_step += 1
+
+ if global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+
+ begin = time.perf_counter()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py
index 1ebc1422b064..a734c50d8ee0 100644
--- a/examples/research_projects/lora/train_text_to_image_lora.py
+++ b/examples/research_projects/lora/train_text_to_image_lora.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
index 57ad77477b0d..19432142f541 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
index 126a10b4f9e9..a886f9ab27ef 100644
--- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
index e10564fa59ef..7f5dc8ece9fc 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
+++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/pixart/.gitignore b/examples/research_projects/pixart/.gitignore
new file mode 100644
index 000000000000..4be0fcb237f5
--- /dev/null
+++ b/examples/research_projects/pixart/.gitignore
@@ -0,0 +1,2 @@
+images/
+output/
\ No newline at end of file
diff --git a/examples/research_projects/pixart/controlnet_pixart_alpha.py b/examples/research_projects/pixart/controlnet_pixart_alpha.py
new file mode 100644
index 000000000000..8f2eb974398d
--- /dev/null
+++ b/examples/research_projects/pixart/controlnet_pixart_alpha.py
@@ -0,0 +1,291 @@
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models import PixArtTransformer2DModel
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+
+
+class PixArtControlNetAdapterBlock(nn.Module):
+ def __init__(
+ self,
+ block_index,
+ # taken from PixArtTransformer2DModel
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 72,
+ dropout: float = 0.0,
+ cross_attention_dim: Optional[int] = 1152,
+ attention_bias: bool = True,
+ activation_fn: str = "gelu-approximate",
+ num_embeds_ada_norm: Optional[int] = 1000,
+ upcast_attention: bool = False,
+ norm_type: str = "ada_norm_single",
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ attention_type: Optional[str] = "default",
+ ):
+ super().__init__()
+
+ self.block_index = block_index
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ # the first block has a zero before layer
+ if self.block_index == 0:
+ self.before_proj = nn.Linear(self.inner_dim, self.inner_dim)
+ nn.init.zeros_(self.before_proj.weight)
+ nn.init.zeros_(self.before_proj.bias)
+
+ self.transformer_block = BasicTransformerBlock(
+ self.inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ )
+
+ self.after_proj = nn.Linear(self.inner_dim, self.inner_dim)
+ nn.init.zeros_(self.after_proj.weight)
+ nn.init.zeros_(self.after_proj.bias)
+
+ def train(self, mode: bool = True):
+ self.transformer_block.train(mode)
+
+ if self.block_index == 0:
+ self.before_proj.train(mode)
+
+ self.after_proj.train(mode)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ controlnet_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ if self.block_index == 0:
+ controlnet_states = self.before_proj(controlnet_states)
+ controlnet_states = hidden_states + controlnet_states
+
+ controlnet_states_down = self.transformer_block(
+ hidden_states=controlnet_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ class_labels=None,
+ )
+
+ controlnet_states_left = self.after_proj(controlnet_states_down)
+
+ return controlnet_states_left, controlnet_states_down
+
+
+class PixArtControlNetAdapterModel(ModelMixin, ConfigMixin):
+ # N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 ControlNet-Transformer
+ @register_to_config
+ def __init__(self, num_layers=13) -> None:
+ super().__init__()
+
+ self.num_layers = num_layers
+
+ self.controlnet_blocks = nn.ModuleList(
+ [PixArtControlNetAdapterBlock(block_index=i) for i in range(num_layers)]
+ )
+
+ @classmethod
+ def from_transformer(cls, transformer: PixArtTransformer2DModel):
+ control_net = PixArtControlNetAdapterModel()
+
+ # copied the specified number of blocks from the transformer
+ for depth in range(control_net.num_layers):
+ control_net.controlnet_blocks[depth].transformer_block.load_state_dict(
+ transformer.transformer_blocks[depth].state_dict()
+ )
+
+ return control_net
+
+ def train(self, mode: bool = True):
+ for block in self.controlnet_blocks:
+ block.train(mode)
+
+
+class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
+ def __init__(
+ self,
+ transformer: PixArtTransformer2DModel,
+ controlnet: PixArtControlNetAdapterModel,
+ blocks_num=13,
+ init_from_transformer=False,
+ training=False,
+ ):
+ super().__init__()
+
+ self.blocks_num = blocks_num
+ self.gradient_checkpointing = False
+ self.register_to_config(**transformer.config)
+ self.training = training
+
+ if init_from_transformer:
+ # copies the specified number of blocks from the transformer
+ controlnet.from_transformer(transformer, self.blocks_num)
+
+ self.transformer = transformer
+ self.controlnet = controlnet
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ controlnet_cond: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ):
+ if self.transformer.use_additional_conditions and added_cond_kwargs is None:
+ raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ batch_size = hidden_states.shape[0]
+ height, width = (
+ hidden_states.shape[-2] // self.transformer.config.patch_size,
+ hidden_states.shape[-1] // self.transformer.config.patch_size,
+ )
+ hidden_states = self.transformer.pos_embed(hidden_states)
+
+ timestep, embedded_timestep = self.transformer.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ if self.transformer.caption_projection is not None:
+ encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ controlnet_states_down = None
+ if controlnet_cond is not None:
+ controlnet_states_down = self.transformer.pos_embed(controlnet_cond)
+
+ # 2. Blocks
+ for block_index, block in enumerate(self.transformer.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ # rc todo: for training and gradient checkpointing
+ print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
+ exit(1)
+
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ cross_attention_kwargs,
+ None,
+ )
+ else:
+ # the control nets are only used for the blocks 1 to self.blocks_num
+ if block_index > 0 and block_index <= self.blocks_num and controlnet_states_down is not None:
+ controlnet_states_left, controlnet_states_down = self.controlnet.controlnet_blocks[
+ block_index - 1
+ ](
+ hidden_states=hidden_states, # used only in the first block
+ controlnet_states=controlnet_states_down,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ hidden_states = hidden_states + controlnet_states_left
+
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=None,
+ )
+
+ # 3. Output
+ shift, scale = (
+ self.transformer.scale_shift_table[None]
+ + embedded_timestep[:, None].to(self.transformer.scale_shift_table.device)
+ ).chunk(2, dim=1)
+ hidden_states = self.transformer.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
+ hidden_states = self.transformer.proj_out(hidden_states)
+ hidden_states = hidden_states.squeeze(1)
+
+ # unpatchify
+ hidden_states = hidden_states.reshape(
+ shape=(
+ -1,
+ height,
+ width,
+ self.transformer.config.patch_size,
+ self.transformer.config.patch_size,
+ self.transformer.out_channels,
+ )
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(
+ -1,
+ self.transformer.out_channels,
+ height * self.transformer.config.patch_size,
+ width * self.transformer.config.patch_size,
+ )
+ )
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
new file mode 100644
index 000000000000..4065a854c22d
--- /dev/null
+++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
@@ -0,0 +1,1098 @@
+# Copyright 2024 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers.image_processor import PipelineImageInput, PixArtImageProcessor
+from diffusers.models import AutoencoderKL, PixArtTransformer2DModel
+from diffusers.pipelines import DiffusionPipeline, ImagePipelineOutput
+from diffusers.schedulers import DPMSolverMultistepScheduler
+from diffusers.utils import (
+ BACKENDS_MAPPING,
+ deprecate,
+ is_bs4_available,
+ is_ftfy_available,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import PixArtAlphaPipeline
+
+ >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
+ >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+ >>> # Enable memory optimizations.
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+ASPECT_RATIO_1024_BIN = {
+ "0.25": [512.0, 2048.0],
+ "0.28": [512.0, 1856.0],
+ "0.32": [576.0, 1792.0],
+ "0.33": [576.0, 1728.0],
+ "0.35": [576.0, 1664.0],
+ "0.4": [640.0, 1600.0],
+ "0.42": [640.0, 1536.0],
+ "0.48": [704.0, 1472.0],
+ "0.5": [704.0, 1408.0],
+ "0.52": [704.0, 1344.0],
+ "0.57": [768.0, 1344.0],
+ "0.6": [768.0, 1280.0],
+ "0.68": [832.0, 1216.0],
+ "0.72": [832.0, 1152.0],
+ "0.78": [896.0, 1152.0],
+ "0.82": [896.0, 1088.0],
+ "0.88": [960.0, 1088.0],
+ "0.94": [960.0, 1024.0],
+ "1.0": [1024.0, 1024.0],
+ "1.07": [1024.0, 960.0],
+ "1.13": [1088.0, 960.0],
+ "1.21": [1088.0, 896.0],
+ "1.29": [1152.0, 896.0],
+ "1.38": [1152.0, 832.0],
+ "1.46": [1216.0, 832.0],
+ "1.67": [1280.0, 768.0],
+ "1.75": [1344.0, 768.0],
+ "2.0": [1408.0, 704.0],
+ "2.09": [1472.0, 704.0],
+ "2.4": [1536.0, 640.0],
+ "2.5": [1600.0, 640.0],
+ "3.0": [1728.0, 576.0],
+ "4.0": [2048.0, 512.0],
+}
+
+ASPECT_RATIO_512_BIN = {
+ "0.25": [256.0, 1024.0],
+ "0.28": [256.0, 928.0],
+ "0.32": [288.0, 896.0],
+ "0.33": [288.0, 864.0],
+ "0.35": [288.0, 832.0],
+ "0.4": [320.0, 800.0],
+ "0.42": [320.0, 768.0],
+ "0.48": [352.0, 736.0],
+ "0.5": [352.0, 704.0],
+ "0.52": [352.0, 672.0],
+ "0.57": [384.0, 672.0],
+ "0.6": [384.0, 640.0],
+ "0.68": [416.0, 608.0],
+ "0.72": [416.0, 576.0],
+ "0.78": [448.0, 576.0],
+ "0.82": [448.0, 544.0],
+ "0.88": [480.0, 544.0],
+ "0.94": [480.0, 512.0],
+ "1.0": [512.0, 512.0],
+ "1.07": [512.0, 480.0],
+ "1.13": [544.0, 480.0],
+ "1.21": [544.0, 448.0],
+ "1.29": [576.0, 448.0],
+ "1.38": [576.0, 416.0],
+ "1.46": [608.0, 416.0],
+ "1.67": [640.0, 384.0],
+ "1.75": [672.0, 384.0],
+ "2.0": [704.0, 352.0],
+ "2.09": [736.0, 352.0],
+ "2.4": [768.0, 320.0],
+ "2.5": [800.0, 320.0],
+ "3.0": [864.0, 288.0],
+ "4.0": [1024.0, 256.0],
+}
+
+ASPECT_RATIO_256_BIN = {
+ "0.25": [128.0, 512.0],
+ "0.28": [128.0, 464.0],
+ "0.32": [144.0, 448.0],
+ "0.33": [144.0, 432.0],
+ "0.35": [144.0, 416.0],
+ "0.4": [160.0, 400.0],
+ "0.42": [160.0, 384.0],
+ "0.48": [176.0, 368.0],
+ "0.5": [176.0, 352.0],
+ "0.52": [176.0, 336.0],
+ "0.57": [192.0, 336.0],
+ "0.6": [192.0, 320.0],
+ "0.68": [208.0, 304.0],
+ "0.72": [208.0, 288.0],
+ "0.78": [224.0, 288.0],
+ "0.82": [224.0, 272.0],
+ "0.88": [240.0, 272.0],
+ "0.94": [240.0, 256.0],
+ "1.0": [256.0, 256.0],
+ "1.07": [256.0, 240.0],
+ "1.13": [272.0, 240.0],
+ "1.21": [272.0, 224.0],
+ "1.29": [288.0, 224.0],
+ "1.38": [288.0, 208.0],
+ "1.46": [304.0, 208.0],
+ "1.67": [320.0, 192.0],
+ "1.75": [336.0, 192.0],
+ "2.0": [352.0, 176.0],
+ "2.09": [368.0, 176.0],
+ "2.4": [384.0, 160.0],
+ "2.5": [400.0, 160.0],
+ "3.0": [432.0, 144.0],
+ "4.0": [512.0, 128.0],
+}
+
+
+def get_closest_hw(width, height, image_size):
+ if image_size == 1024:
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+ elif image_size == 512:
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
+ else:
+ raise ValueError("Invalid image size")
+
+ height, width = PixArtImageProcessor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ return width, height
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class PixArtAlphaControlnetPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using PixArt-Alpha.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`PixArtTransformer2DModel`]):
+ A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ bad_punct_regex = re.compile(
+ r"["
+ + "#®•©™&@·º½¾¿¡§~"
+ + r"\)"
+ + r"\("
+ + r"\]"
+ + r"\["
+ + r"\}"
+ + r"\{"
+ + r"\|"
+ + "\\"
+ + r"\/"
+ + r"\*"
+ + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKL,
+ transformer: PixArtTransformer2DModel,
+ controlnet: PixArtControlNetAdapterModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ # change to the controlnet transformer model
+ transformer = PixArtControlNetTransformerModel(transformer=transformer, controlnet=controlnet)
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 120,
+ **kwargs,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
+ string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
+ """
+
+ if "mask_feature" in kwargs:
+ deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because T5 can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.controlnet.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ negative_prompt_attention_mask = uncond_input.attention_mask
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ image=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ if image is not None:
+ self.check_image(image, prompt, prompt_embeds)
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # based on pipeline_pixart_inpaiting.py
+ def prepare_image_latents(self, image, device, dtype):
+ image = image.to(device=device, dtype=dtype)
+
+ image_latents = self.vae.encode(image).latent_dist.sample()
+ image_latents = image_latents * self.vae.config.scaling_factor
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 20,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ # rc todo: controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ # rc todo: control_guidance_start = 0.0,
+ # rc todo: control_guidance_end = 1.0,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ use_resolution_binning: bool = True,
+ max_sequence_length: int = 120,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool` defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+ the requested resolution. Useful for generating non-square images.
+ max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ if "mask_feature" in kwargs:
+ deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
+ # 1. Check inputs. Raise error if not correct
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 128:
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+ elif self.transformer.config.sample_size == 64:
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
+ elif self.transformer.config.sample_size == 32:
+ aspect_ratio_bin = ASPECT_RATIO_256_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ image,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+
+ # 4.1 Prepare image
+ image_latents = None
+ if image is not None:
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.transformer.controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ image_latents = self.prepare_image_latents(image, device, self.transformer.controlnet.dtype)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Prepare micro-conditions.
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ if self.transformer.config.sample_size == 128:
+ resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
+ aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
+ resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
+ aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+
+ if do_classifier_free_guidance:
+ resolution = torch.cat([resolution, resolution], dim=0)
+ aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ current_timestep = t
+ if not torch.is_tensor(current_timestep):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ is_npu = latent_model_input.device.type == "npu"
+ if isinstance(current_timestep, float):
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ else:
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
+ elif len(current_timestep.shape) == 0:
+ current_timestep = current_timestep[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=current_timestep,
+ controlnet_cond=image_latents,
+ # rc todo: controlnet_conditioning_scale=1.0,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+ else:
+ noise_pred = noise_pred
+
+ # compute previous image: x_t -> x_t-1
+ if num_inference_steps == 1:
+ # For DMD one step sampling: https://arxiv.org/abs/2311.18828
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ if use_resolution_binning:
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/examples/research_projects/pixart/requirements.txt b/examples/research_projects/pixart/requirements.txt
new file mode 100644
index 000000000000..2b307927ee9f
--- /dev/null
+++ b/examples/research_projects/pixart/requirements.txt
@@ -0,0 +1,6 @@
+transformers
+SentencePiece
+torchvision
+controlnet-aux
+datasets
+# wandb
\ No newline at end of file
diff --git a/examples/research_projects/pixart/run_pixart_alpha_controlnet_pipeline.py b/examples/research_projects/pixart/run_pixart_alpha_controlnet_pipeline.py
new file mode 100644
index 000000000000..0014c590541b
--- /dev/null
+++ b/examples/research_projects/pixart/run_pixart_alpha_controlnet_pipeline.py
@@ -0,0 +1,75 @@
+import torch
+import torchvision.transforms as T
+from controlnet_aux import HEDdetector
+
+from diffusers.utils import load_image
+from examples.research_projects.pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel
+from examples.research_projects.pixart.pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
+
+
+controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"
+
+weight_dtype = torch.float16
+image_size = 1024
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+torch.manual_seed(0)
+
+# load controlnet
+controlnet = PixArtControlNetAdapterModel.from_pretrained(
+ controlnet_repo_id,
+ torch_dtype=weight_dtype,
+ use_safetensors=True,
+).to(device)
+
+pipe = PixArtAlphaControlnetPipeline.from_pretrained(
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
+ controlnet=controlnet,
+ torch_dtype=weight_dtype,
+ use_safetensors=True,
+).to(device)
+
+images_path = "images"
+control_image_file = "0_7.jpg"
+
+# prompt = "cinematic photo of superman in action . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
+# prompt = "yellow modern car, city in background, beautiful rainy day"
+# prompt = "modern villa, clear sky, suny day . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
+# prompt = "robot dog toy in park . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
+# prompt = "purple car, on highway, beautiful sunny day"
+# prompt = "realistical photo of a loving couple standing in the open kitchen of the living room, cooking ."
+prompt = "battleship in space, galaxy in background"
+
+control_image_name = control_image_file.split(".")[0]
+
+control_image = load_image(f"{images_path}/{control_image_file}")
+print(control_image.size)
+height, width = control_image.size
+
+hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
+
+condition_transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB")),
+ T.CenterCrop([image_size, image_size]),
+ ]
+)
+
+control_image = condition_transform(control_image)
+hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)
+
+hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")
+
+# run pipeline
+with torch.no_grad():
+ out = pipe(
+ prompt=prompt,
+ image=hed_edge,
+ num_inference_steps=14,
+ guidance_scale=4.5,
+ height=image_size,
+ width=image_size,
+ )
+
+ out.images[0].save(f"{images_path}//{control_image_name}_output.jpg")
diff --git a/examples/research_projects/pixart/train_controlnet_hf_diffusers.sh b/examples/research_projects/pixart/train_controlnet_hf_diffusers.sh
new file mode 100755
index 000000000000..0abd88f19e18
--- /dev/null
+++ b/examples/research_projects/pixart/train_controlnet_hf_diffusers.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+# run
+# accelerate config
+
+# check with
+# accelerate env
+
+export MODEL_DIR="PixArt-alpha/PixArt-XL-2-512x512"
+export OUTPUT_DIR="output/pixart-controlnet-hf-diffusers-test"
+
+accelerate launch ./train_pixart_controlnet_hf.py --mixed_precision="fp16" \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --report_to="wandb" \
+ --seed=42 \
+ --dataloader_num_workers=8
+# --lr_scheduler="cosine" --lr_warmup_steps=0 \
diff --git a/examples/research_projects/pixart/train_pixart_controlnet_hf.py b/examples/research_projects/pixart/train_pixart_controlnet_hf.py
new file mode 100644
index 000000000000..67ec30da0ece
--- /dev/null
+++ b/examples/research_projects/pixart/train_pixart_controlnet_hf.py
@@ -0,0 +1,1174 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion for text2image with HuggingFace diffusers."""
+
+import argparse
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import T5EncoderModel, T5Tokenizer
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler
+from diffusers.models import PixArtTransformer2DModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+from examples.research_projects.pixart.controlnet_pixart_alpha import (
+ PixArtControlNetAdapterModel,
+ PixArtControlNetTransformerModel,
+)
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.29.2")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def log_validation(
+ vae,
+ transformer,
+ controlnet,
+ tokenizer,
+ scheduler,
+ text_encoder,
+ args,
+ accelerator,
+ weight_dtype,
+ step,
+ is_final_validation=False,
+):
+ if weight_dtype == torch.float16 or weight_dtype == torch.bfloat16:
+ raise ValueError(
+ "Validation is not supported with mixed precision training, disable validation and use the validation script, that will generate images from the saved checkpoints."
+ )
+
+ if not is_final_validation:
+ logger.info(f"Running validation step {step} ... ")
+
+ controlnet = accelerator.unwrap_model(controlnet)
+ pipeline = PixArtAlphaControlnetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ controlnet=controlnet,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ else:
+ logger.info("Running validation - final ... ")
+
+ controlnet = PixArtControlNetAdapterModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
+
+ pipeline = PixArtAlphaControlnetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ controlnet=controlnet,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = Image.open(validation_image).convert("RGB")
+ validation_image = validation_image.resize((args.resolution, args.resolution))
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ image = pipeline(
+ prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
+ ).images[0]
+ images.append(image)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images = [np.asarray(validation_image)]
+
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({tracker_key: formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ logger.info("Validation done!!")
+
+ return image_logs
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model=str, dataset_name=str, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"\n"
+
+ model_description = f"""
+# controlnet-{repo_id}
+
+These are controlnet weights trained on {base_model} with new type of conditioning.
+{img_str}
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "pixart-alpha",
+ "pixart-alpha-diffusers",
+ "text-to-image",
+ "diffusers",
+ "controlnet",
+ "diffusers-training",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--controlnet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+ " If not specified controlnet weights are initialized from the transformer.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ nargs="+",
+ default=None,
+ help="One or more prompts to be evaluated every `--validation_steps`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.",
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="pixart-controlnet",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="pixart_controlnet",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # See Section 3.1. of the paper.
+ max_length = 120
+
+ # For mixed precision training we cast all non-trainable weigths (vae, text_encoder) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler", torch_dtype=weight_dtype
+ )
+ tokenizer = T5Tokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, torch_dtype=weight_dtype
+ )
+
+ text_encoder = T5EncoderModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, torch_dtype=weight_dtype
+ )
+ text_encoder.requires_grad_(False)
+ text_encoder.to(accelerator.device)
+
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ vae.requires_grad_(False)
+ vae.to(accelerator.device)
+
+ transformer = PixArtTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer")
+ transformer.to(accelerator.device)
+ transformer.requires_grad_(False)
+
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet = PixArtControlNetAdapterModel.from_pretrained(args.controlnet_model_name_or_path)
+ else:
+ logger.info("Initializing controlnet weights from transformer.")
+ controlnet = PixArtControlNetAdapterModel.from_transformer(transformer)
+
+ transformer.to(dtype=weight_dtype)
+
+ controlnet.to(accelerator.device)
+ controlnet.train()
+
+ def unwrap_model(model, keep_fp32_wrapper=True):
+ model = accelerator.unwrap_model(model, keep_fp32_wrapper=keep_fp32_wrapper)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for _, model in enumerate(models):
+ if isinstance(model, PixArtControlNetTransformerModel):
+ print(f"Saving model {model.__class__.__name__} to {output_dir}")
+ model.controlnet.save_pretrained(os.path.join(output_dir, "controlnet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ # rc todo: test and load the controlenet adapter and transformer
+ raise ValueError("load model hook not tested")
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ if isinstance(model, PixArtControlNetTransformerModel):
+ load_model = PixArtControlNetAdapterModel.from_pretrained(input_dir, subfolder="controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ transformer.enable_xformers_memory_efficient_attention()
+ controlnet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if unwrap_model(controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Transformer loaded as datatype {unwrap_model(controlnet).dtype}. The trainable parameters should be in torch.float32."
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+ controlnet.enable_gradient_checkpointing()
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ params_to_optimize = controlnet.parameters()
+ optimizer = optimizer_cls(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True, proportion_empty_prompts=0.0, max_length=120):
+ captions = []
+ for caption in examples[caption_column]:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
+ return inputs.input_ids, inputs.attention_mask
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+
+ conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]]
+ examples["conditioning_pixel_values"] = [conditioning_image_transforms(image) for image in conditioning_images]
+
+ examples["input_ids"], examples["prompt_attention_mask"] = tokenize_captions(
+ examples, proportion_empty_prompts=args.proportion_empty_prompts, max_length=max_length
+ )
+
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ prompt_attention_mask = torch.stack([example["prompt_attention_mask"] for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "input_ids": input_ids,
+ "prompt_attention_mask": prompt_attention_mask,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ controlnet_transformer = PixArtControlNetTransformerModel(transformer, controlnet, training=True)
+ controlnet_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ controlnet_transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers(args.tracker_project_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ latent_channels = transformer.config.in_channels
+ for epoch in range(first_epoch, args.num_train_epochs):
+ controlnet_transformer.controlnet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(controlnet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Convert control images to latent space
+ controlnet_image_latents = vae.encode(
+ batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+ ).latent_dist.sample()
+ controlnet_image_latents = controlnet_image_latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
+
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ prompt_embeds = text_encoder(batch["input_ids"], attention_mask=batch["prompt_attention_mask"])[0]
+ prompt_attention_mask = batch["prompt_attention_mask"]
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Prepare micro-conditions.
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ if getattr(transformer, "module", transformer).config.sample_size == 128:
+ resolution = torch.tensor([args.resolution, args.resolution]).repeat(bsz, 1)
+ aspect_ratio = torch.tensor([float(args.resolution / args.resolution)]).repeat(bsz, 1)
+ resolution = resolution.to(dtype=weight_dtype, device=latents.device)
+ aspect_ratio = aspect_ratio.to(dtype=weight_dtype, device=latents.device)
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+ # Predict the noise residual and compute loss
+ model_pred = controlnet_transformer(
+ noisy_latents,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timesteps,
+ controlnet_cond=controlnet_image_latents,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ if transformer.config.out_channels // 2 == latent_channels:
+ model_pred = model_pred.chunk(2, dim=1)[0]
+ else:
+ model_pred = model_pred
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = controlnet_transformer.controlnet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ log_validation(
+ vae,
+ transformer,
+ controlnet_transformer.controlnet,
+ tokenizer,
+ noise_scheduler,
+ text_encoder,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ is_final_validation=False,
+ )
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ controlnet = unwrap_model(controlnet_transformer.controlnet, keep_fp32_wrapper=False)
+ controlnet.save_pretrained(os.path.join(args.output_dir, "controlnet"))
+
+ image_logs = None
+ if args.validation_prompt is not None:
+ image_logs = log_validation(
+ vae,
+ transformer,
+ controlnet,
+ tokenizer,
+ noise_scheduler,
+ text_encoder,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
index cb4260d4653f..19c1f30d82da 100644
--- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
+++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
@@ -233,7 +233,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
diff --git a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py
index 46cabd863dfa..7853695f0566 100644
--- a/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py
+++ b/examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py
@@ -229,11 +229,11 @@ def forward(
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+ Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
- [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
- If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+ [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
@@ -258,10 +258,11 @@ def forward(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
diff --git a/examples/research_projects/pytorch_xla/inference/flux/README.md b/examples/research_projects/pytorch_xla/inference/flux/README.md
new file mode 100644
index 000000000000..9d482e6805a3
--- /dev/null
+++ b/examples/research_projects/pytorch_xla/inference/flux/README.md
@@ -0,0 +1,166 @@
+# Generating images using Flux and PyTorch/XLA
+
+The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
+
+## Create TPU
+
+To create a TPU on Google Cloud, follow [this guide](https://cloud.google.com/tpu/docs/v6e)
+
+## Setup TPU environment
+
+SSH into the VM and install Pytorch, Pytorch/XLA
+
+```bash
+pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
+pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+```
+
+Verify that PyTorch and PyTorch/XLA were installed correctly:
+
+```bash
+python3 -c "import torch; import torch_xla;"
+```
+
+Clone the diffusers repo and install dependencies
+
+```bash
+git clone https://github.com/huggingface/diffusers.git
+cd diffusers
+pip install transformers accelerate sentencepiece structlog
+pip install .
+cd examples/research_projects/pytorch_xla/inference/flux/
+```
+
+## Run the inference job
+
+### Authenticate
+
+**Gated Model**
+
+As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+
+```bash
+huggingface-cli login
+```
+
+Then run:
+
+```bash
+python flux_inference.py
+```
+
+The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest.
+
+On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel):
+
+```bash
+WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
+Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 7.06it/s]
+Loading pipeline components...: 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 3/5 [00:00<00:00, 6.80it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
+Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 6.28it/s]
+2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
+2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
+Loading pipeline components...: 0%| | 0/3 [00:00, ?it/s]2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
+2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
+Loading pipeline components...: 0%| | 0/3 [00:00, ?it/s]2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
+2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
+Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.66it/s]
+Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 4.48it/s]
+Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.32it/s]
+Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.69it/s]
+Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.74it/s]
+Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.10it/s]
+2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
+Loading pipeline components...: 0%| | 0/3 [00:00, ?it/s]2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
+Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.55it/s]
+Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00, 1.46it/s]
+2025-03-14 21:18:34 [info ] starting compilation run...
+2025-03-14 21:18:37 [info ] starting compilation run...
+2025-03-14 21:18:38 [info ] starting compilation run...
+2025-03-14 21:18:39 [info ] starting compilation run...
+2025-03-14 21:18:41 [info ] starting compilation run...
+2025-03-14 21:18:41 [info ] starting compilation run...
+2025-03-14 21:18:42 [info ] starting compilation run...
+2025-03-14 21:18:43 [info ] starting compilation run...
+ 82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 23/28 [13:35<03:04, 36.80s/it]2025-03-14 21:33:42.057559: W torch_xla/csrc/runtime/pjrt_computation_client.cc:667] Failed to deserialize executable: INTERNAL: TfrtTpuExecutable proto deserialization failed while parsing core program!
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.28s/it]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.26s/it]
+2025-03-14 21:36:38 [info ] compilation took 1079.3314765350078 sec.
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
+2025-03-14 21:36:38 [info ] starting inference run...
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
+2025-03-14 21:36:38 [info ] compilation took 1081.89390801001 sec.
+2025-03-14 21:36:39 [info ] starting inference run...
+2025-03-14 21:36:39 [info ] compilation took 1077.1543154849933 sec.
+2025-03-14 21:36:39 [info ] compilation took 1075.7239800530078 sec.
+2025-03-14 21:36:39 [info ] starting inference run...
+2025-03-14 21:36:40 [info ] starting inference run...
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:22<00:00, 35.10s/it]
+2025-03-14 21:36:50 [info ] compilation took 1088.1632604240003 sec.
+2025-03-14 21:36:50 [info ] starting inference run...
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:28<00:00, 35.32s/it]
+2025-03-14 21:36:55 [info ] compilation took 1096.8027802760043 sec.
+2025-03-14 21:36:56 [info ] starting inference run...
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:59<00:00, 36.40s/it]
+2025-03-14 21:37:08 [info ] compilation took 1113.8591305939917 sec.
+2025-03-14 21:37:08 [info ] starting inference run...
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:55<00:00, 36.26s/it]
+2025-03-14 21:37:22 [info ] compilation took 1120.5590810020076 sec.
+2025-03-14 21:37:22 [info ] starting inference run...
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.00it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.98it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.08it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.82it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:08<00:00, 3.34it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.22it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.09it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:11<00:00, 2.41it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.50it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.10it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.27it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 4.80it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.67it/s]
+ 29%|█████████████████████████████████████████████████████████████████████████████▍ | 8/28 [00:01<00:03, 6.08it/s]/home/jfacevedo_google_com/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
+ images = (images * 255).round().astype("uint8")
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.82it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.93it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.98it/s]
+ 71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.03it/s]2025-03-14 21:38:32 [info ] inference time: 5.962021178987925
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
+2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.2685392687970305 sec.
+2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.402720856998348 sec.
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.01it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.06it/s]
+ 71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.01it/s]2025-03-14 21:38:38 [info ] inference time: 5.950578948002658
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.87it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.00it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.86it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.99it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.05it/s]
+2025-03-14 21:38:43 [info ] avg. inference over 5 iterations took 6.763298449796276 sec.
+ 71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.04it/s]2025-03-14 21:38:44 [info ] inference time: 5.949129879008979
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.92it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.10it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
+ 39%|██████████████████████████████████████████████████████████████████████████████████████████████████████████ | 11/28 [00:01<00:02, 5.98it/s]2025-03-14 21:38:46 [info ] avg. inference over 5 iterations took 7.221068455604836 sec.
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.08it/s]
+ 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 26/28 [00:04<00:00, 5.92it/s]2025-03-14 21:38:50 [info ] inference time: 5.954778069004533
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.90it/s]
+ 11%|█████████████████████████████ | 3/28 [00:00<00:04, 6.03it/s]2025-03-14 21:38:50 [info ] avg. inference over 5 iterations took 6.05970350120042 sec.
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
+ 32%|███████████████████████████████████████████████████████████████████████████████████████ | 9/28 [00:01<00:03, 5.99it/s]2025-03-14 21:38:51 [info ] avg. inference over 5 iterations took 6.018543455796316 sec.
+ 54%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 15/28 [00:02<00:02, 6.00it/s]2025-03-14 21:38:52 [info ] avg. inference over 5 iterations took 5.9609976705978625 sec.
+100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.97it/s]
+2025-03-14 21:38:56 [info ] inference time: 5.944058528999449
+2025-03-14 21:38:56 [info ] avg. inference over 5 iterations took 5.952113320800708 sec.
+2025-03-14 21:38:56 [info ] saved metric information as /tmp/metrics_report.txt
+```
\ No newline at end of file
diff --git a/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py b/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py
new file mode 100644
index 000000000000..9c98c9b5ff4f
--- /dev/null
+++ b/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py
@@ -0,0 +1,134 @@
+from argparse import ArgumentParser
+from pathlib import Path
+from time import perf_counter
+
+import structlog
+import torch
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
+import torch_xla.debug.profiler as xp
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.runtime as xr
+from torch_xla.experimental.custom_kernel import FlashAttention
+
+from diffusers import FluxPipeline
+
+
+logger = structlog.get_logger()
+metrics_filepath = "/tmp/metrics_report.txt"
+
+
+def _main(index, args, text_pipe, ckpt_id):
+ cache_path = Path("/tmp/data/compiler_cache_tRiLlium_eXp")
+ cache_path.mkdir(parents=True, exist_ok=True)
+ xr.initialize_cache(str(cache_path), readonly=False)
+
+ profile_path = Path("/tmp/data/profiler_out_tRiLlium_eXp")
+ profile_path.mkdir(parents=True, exist_ok=True)
+ profiler_port = 9012
+ profile_duration = args.profile_duration
+ if args.profile:
+ logger.info(f"starting profiler on port {profiler_port}")
+ _ = xp.start_server(profiler_port)
+ device0 = xm.xla_device()
+
+ logger.info(f"loading flux from {ckpt_id}")
+ flux_pipe = FluxPipeline.from_pretrained(
+ ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
+ ).to(device0)
+ flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
+ FlashAttention.DEFAULT_BLOCK_SIZES = {
+ "block_q": 1536,
+ "block_k_major": 1536,
+ "block_k": 1536,
+ "block_b": 1536,
+ "block_q_major_dkv": 1536,
+ "block_k_major_dkv": 1536,
+ "block_q_dkv": 1536,
+ "block_k_dkv": 1536,
+ "block_q_dq": 1536,
+ "block_k_dq": 1536,
+ "block_k_major_dq": 1536,
+ }
+
+ prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
+ width = args.width
+ height = args.height
+ guidance = args.guidance
+ n_steps = 4 if args.schnell else 28
+
+ logger.info("starting compilation run...")
+ ts = perf_counter()
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
+ prompt=prompt, prompt_2=None, max_sequence_length=512
+ )
+ prompt_embeds = prompt_embeds.to(device0)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
+
+ image = flux_pipe(
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_inference_steps=28,
+ guidance_scale=guidance,
+ height=height,
+ width=width,
+ ).images[0]
+ logger.info(f"compilation took {perf_counter() - ts} sec.")
+ image.save("/tmp/compile_out.png")
+
+ base_seed = 4096 if args.seed is None else args.seed
+ seed_range = 1000
+ unique_seed = base_seed + index * seed_range
+ xm.set_rng_state(seed=unique_seed, device=device0)
+ times = []
+ logger.info("starting inference run...")
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
+ prompt=prompt, prompt_2=None, max_sequence_length=512
+ )
+ prompt_embeds = prompt_embeds.to(device0)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
+ for _ in range(args.itters):
+ ts = perf_counter()
+
+ if args.profile:
+ xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
+ image = flux_pipe(
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ num_inference_steps=n_steps,
+ guidance_scale=guidance,
+ height=height,
+ width=width,
+ ).images[0]
+ inference_time = perf_counter() - ts
+ if index == 0:
+ logger.info(f"inference time: {inference_time}")
+ times.append(inference_time)
+ logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
+ image.save(f"/tmp/inference_out-{index}.png")
+ if index == 0:
+ metrics_report = met.metrics_report()
+ with open(metrics_filepath, "w+") as fout:
+ fout.write(metrics_report)
+ logger.info(f"saved metric information as {metrics_filepath}")
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev")
+ parser.add_argument("--width", type=int, default=1024, help="width of the image to generate")
+ parser.add_argument("--height", type=int, default=1024, help="height of the image to generate")
+ parser.add_argument("--guidance", type=float, default=3.5, help="gauidance strentgh for dev")
+ parser.add_argument("--seed", type=int, default=None, help="seed for inference")
+ parser.add_argument("--profile", action="store_true", help="enable profiling")
+ parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.")
+ parser.add_argument("--itters", type=int, default=15, help="tiems to run inference and get avg time in sec.")
+ args = parser.parse_args()
+ if args.schnell:
+ ckpt_id = "black-forest-labs/FLUX.1-schnell"
+ else:
+ ckpt_id = "black-forest-labs/FLUX.1-dev"
+ text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu")
+ xmp.spawn(_main, args=(args, text_pipe, ckpt_id))
diff --git a/examples/research_projects/pytorch_xla/README.md b/examples/research_projects/pytorch_xla/training/text_to_image/README.md
similarity index 87%
rename from examples/research_projects/pytorch_xla/README.md
rename to examples/research_projects/pytorch_xla/training/text_to_image/README.md
index a6901d5ada9d..06013b8a61e0 100644
--- a/examples/research_projects/pytorch_xla/README.md
+++ b/examples/research_projects/pytorch_xla/training/text_to_image/README.md
@@ -7,13 +7,14 @@ It has been tested on v4 and v5p TPU versions. Training code has been tested on
This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
where we shard the input batches over the TPU devices.
-As of 9-11-2024, these are some expected step times.
+As of 10-31-2024, these are some expected step times.
| accelerator | global batch size | step time (seconds) |
| ----------- | ----------------- | --------- |
-| v5p-128 | 1024 | 0.245 |
-| v5p-256 | 2048 | 0.234 |
-| v5p-512 | 4096 | 0.2498 |
+| v5p-512 | 16384 | 1.01 |
+| v5p-256 | 8192 | 1.01 |
+| v5p-128 | 4096 | 1.0 |
+| v5p-64 | 2048 | 1.01 |
## Create TPU
@@ -43,8 +44,9 @@ Install PyTorch and PyTorch/XLA nightly versions:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
-pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
-pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
+pip3 install --pre torch==2.6.0.dev20241031+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241031.cxx11-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
+pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
'
```
@@ -88,17 +90,18 @@ are fixed.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='
-export XLA_DISABLE_FUNCTIONALIZATION=1
+export XLA_DISABLE_FUNCTIONALIZATION=0
export PROFILE_DIR=/tmp/
export CACHE_DIR=/tmp/
export DATASET_NAME=lambdalabs/naruto-blip-captions
export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
export TRAIN_STEPS=50
export OUTPUT_DIR=/tmp/trained-model/
-python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4'
-
+python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
```
+Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
+
### Environment Envs Explained
* `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer.
diff --git a/examples/research_projects/pytorch_xla/requirements.txt b/examples/research_projects/pytorch_xla/training/text_to_image/requirements.txt
similarity index 100%
rename from examples/research_projects/pytorch_xla/requirements.txt
rename to examples/research_projects/pytorch_xla/training/text_to_image/requirements.txt
diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
similarity index 89%
rename from examples/research_projects/pytorch_xla/train_text_to_image_xla.py
rename to examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
index 5d9d8c540f11..9719585d3dfb 100644
--- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
+++ b/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
@@ -140,33 +140,43 @@ def run_optimizer(self):
self.optimizer.step()
def start_training(self):
- times = []
- last_time = time.time()
- step = 0
- while True:
- if self.global_step >= self.args.max_train_steps:
- xm.mark_step()
- break
- if step == 4 and PROFILE_DIR is not None:
- xm.wait_device_ops()
- xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+ dataloader_exception = False
+ measure_start_step = args.measure_start_step
+ assert measure_start_step < self.args.max_train_steps
+ total_time = 0
+ for step in range(0, self.args.max_train_steps):
try:
batch = next(self.dataloader)
except Exception as e:
+ dataloader_exception = True
print(e)
break
+ if step == measure_start_step and PROFILE_DIR is not None:
+ xm.wait_device_ops()
+ xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+ last_time = time.time()
loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
- step_time = time.time() - last_time
- if step >= 10:
- times.append(step_time)
- print(f"step: {step}, step_time: {step_time}")
- if step % 5 == 0:
- print(f"step: {step}, loss: {loss}")
- last_time = time.time()
self.global_step += 1
- step += 1
- # print(f"Average step time: {sum(times)/len(times)}")
- xm.wait_device_ops()
+
+ def print_loss_closure(step, loss):
+ print(f"Step: {step}, Loss: {loss}")
+
+ if args.print_loss:
+ xm.add_step_closure(
+ print_loss_closure,
+ args=(
+ self.global_step,
+ loss,
+ ),
+ )
+ xm.mark_step()
+ if not dataloader_exception:
+ xm.wait_device_ops()
+ total_time = time.time() - last_time
+ print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
+ else:
+ print("dataloader exception happen, skip result")
+ return
def step_fn(
self,
@@ -180,7 +190,10 @@ def step_fn(
noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)
bsz = latents.shape[0]
timesteps = torch.randint(
- 0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ 0,
+ self.noise_scheduler.config.num_train_timesteps,
+ (bsz,),
+ device=latents.device,
)
timesteps = timesteps.long()
@@ -224,9 +237,6 @@ def step_fn(
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
- parser.add_argument(
- "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
- )
parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
parser.add_argument(
"--pretrained_model_name_or_path",
@@ -258,12 +268,6 @@ def parse_args():
" or to a folder containing files that 🤗 Datasets can understand."
),
)
- parser.add_argument(
- "--dataset_config_name",
- type=str,
- default=None,
- help="The config of the Dataset, leave as None if there's only one config.",
- )
parser.add_argument(
"--train_data_dir",
type=str,
@@ -283,15 +287,6 @@ def parse_args():
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
- parser.add_argument(
- "--max_train_samples",
- type=int,
- default=None,
- help=(
- "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
- ),
- )
parser.add_argument(
"--output_dir",
type=str,
@@ -304,7 +299,6 @@ def parse_args():
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
@@ -374,12 +368,19 @@ def parse_args():
default=1,
help=("Number of subprocesses to use for data loading to cpu."),
)
+ parser.add_argument(
+ "--loader_prefetch_factor",
+ type=int,
+ default=2,
+ help=("Number of batches loaded in advance by each worker."),
+ )
parser.add_argument(
"--device_prefetch_size",
type=int,
default=1,
help=("Number of subprocesses to use for data loading to tpu from cpu. "),
)
+ parser.add_argument("--measure_start_step", type=int, default=10, help="Step to start profiling.")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -394,12 +395,8 @@ def parse_args():
"--mixed_precision",
type=str,
default=None,
- choices=["no", "fp16", "bf16"],
- help=(
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
- ),
+ choices=["no", "bf16"],
+ help=("Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"),
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
@@ -409,6 +406,12 @@ def parse_args():
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
+ parser.add_argument(
+ "--print_loss",
+ default=False,
+ action="store_true",
+ help=("Print loss at every step."),
+ )
args = parser.parse_args()
@@ -436,7 +439,6 @@ def load_dataset(args):
# Downloading and loading a dataset from the hub.
dataset = datasets.load_dataset(
args.dataset_name,
- args.dataset_config_name,
cache_dir=args.cache_dir,
data_dir=args.train_data_dir,
)
@@ -481,9 +483,7 @@ def main(args):
_ = xp.start_server(PORT)
num_devices = xr.global_runtime_device_count()
- device_ids = np.arange(num_devices)
- mesh_shape = (num_devices, 1)
- mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
+ mesh = xs.get_1d_mesh("data")
xs.set_global_mesh(mesh)
text_encoder = CLIPTextModel.from_pretrained(
@@ -520,6 +520,7 @@ def main(args):
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+ unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -530,15 +531,12 @@ def main(args):
# as these weights are only used for inference, keeping weights in full
# precision is not required.
weight_dtype = torch.float32
- if args.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif args.mixed_precision == "bf16":
+ if args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
device = xm.xla_device()
- print("device: ", device)
- print("weight_dtype: ", weight_dtype)
+ # Move text_encode and vae to device and cast to weight_dtype
text_encoder = text_encoder.to(device, dtype=weight_dtype)
vae = vae.to(device, dtype=weight_dtype)
unet = unet.to(device, dtype=weight_dtype)
@@ -606,24 +604,27 @@ def collate_fn(examples):
collate_fn=collate_fn,
num_workers=args.dataloader_num_workers,
batch_size=args.train_batch_size,
+ prefetch_factor=args.loader_prefetch_factor,
)
train_dataloader = pl.MpDeviceLoader(
train_dataloader,
device,
input_sharding={
- "pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True),
- "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
+ "pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
+ "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
},
loader_prefetch_size=args.loader_prefetch_size,
device_prefetch_size=args.device_prefetch_size,
)
+ num_hosts = xr.process_count()
+ num_devices_per_host = num_devices // num_hosts
if xm.is_master_ordinal():
print("***** Running training *****")
- print(f"Instantaneous batch size per device = {args.train_batch_size}")
+ print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
print(
- f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}"
+ f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
)
print(f" Total optimization steps = {args.max_train_steps}")
diff --git a/examples/research_projects/rdm/pipeline_rdm.py b/examples/research_projects/rdm/pipeline_rdm.py
index f8093a3f217d..e84568786f50 100644
--- a/examples/research_projects/rdm/pipeline_rdm.py
+++ b/examples/research_projects/rdm/pipeline_rdm.py
@@ -78,7 +78,7 @@ def __init__(
feature_extractor=feature_extractor,
)
# Copy from statement here and all the methods we take from stable_diffusion_pipeline
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.retriever = retriever
diff --git a/examples/research_projects/realfill/requirements.txt b/examples/research_projects/realfill/requirements.txt
index 8fbaf908a2c8..c45334be97f9 100644
--- a/examples/research_projects/realfill/requirements.txt
+++ b/examples/research_projects/realfill/requirements.txt
@@ -6,4 +6,4 @@ torch==2.2.0
torchvision>=0.16
ftfy==6.1.1
tensorboard==2.14.0
-Jinja2==3.1.4
+Jinja2==3.1.6
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
index 5f7ca2262dcc..26caba5a42c1 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
index 663dbbf99473..410cd74a5b7b 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
index d16780131139..c02a59a0077a 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1475,7 +1475,6 @@ def load_model_hook(models, input_dir):
optimizer = optimizer_class(
params_to_optimize,
- lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
index d3bf95305dad..2ca555889cf9 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
index a4b4d69bb892..3e6199a09a55 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
index bab86bf21a76..abc439912664 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
index a056bcfc8cb1..4738e39e832e 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/sd3_lora_colab/compute_embeddings.py b/examples/research_projects/sd3_lora_colab/compute_embeddings.py
index 5014752ffe34..6571f265c702 100644
--- a/examples/research_projects/sd3_lora_colab/compute_embeddings.py
+++ b/examples/research_projects/sd3_lora_colab/compute_embeddings.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
index 163ff8f08931..f5bee58d4534 100644
--- a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
+++ b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -765,7 +765,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
diff --git a/examples/research_projects/vae/vae_roundtrip.py b/examples/research_projects/vae/vae_roundtrip.py
index 65c2b43a9bde..8388a352b2f2 100644
--- a/examples/research_projects/vae/vae_roundtrip.py
+++ b/examples/research_projects/vae/vae_roundtrip.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/research_projects/wuerstchen/text_to_image/README.md
similarity index 100%
rename from examples/wuerstchen/text_to_image/README.md
rename to examples/research_projects/wuerstchen/text_to_image/README.md
diff --git a/examples/wuerstchen/text_to_image/__init__.py b/examples/research_projects/wuerstchen/text_to_image/__init__.py
similarity index 100%
rename from examples/wuerstchen/text_to_image/__init__.py
rename to examples/research_projects/wuerstchen/text_to_image/__init__.py
diff --git a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py b/examples/research_projects/wuerstchen/text_to_image/modeling_efficient_net_encoder.py
similarity index 100%
rename from examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py
rename to examples/research_projects/wuerstchen/text_to_image/modeling_efficient_net_encoder.py
diff --git a/examples/wuerstchen/text_to_image/requirements.txt b/examples/research_projects/wuerstchen/text_to_image/requirements.txt
similarity index 100%
rename from examples/wuerstchen/text_to_image/requirements.txt
rename to examples/research_projects/wuerstchen/text_to_image/requirements.txt
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
similarity index 99%
rename from examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
rename to examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index b4c9a44bb5b2..9e2302f1b1ba 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -50,7 +50,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
similarity index 99%
rename from examples/wuerstchen/text_to_image/train_text_to_image_prior.py
rename to examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
index eba8de69203a..83647097d28a 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
+++ b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -51,7 +51,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/server/README.md b/examples/server/README.md
new file mode 100644
index 000000000000..8ad0ed3cbe6a
--- /dev/null
+++ b/examples/server/README.md
@@ -0,0 +1,61 @@
+
+# Create a server
+
+Diffusers' pipelines can be used as an inference engine for a server. It supports concurrent and multithreaded requests to generate images that may be requested by multiple users at the same time.
+
+This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server, but feel free to use any pipeline you want.
+
+
+Start by navigating to the `examples/server` folder and installing all of the dependencies.
+
+```py
+pip install .
+pip install -f requirements.txt
+```
+
+Launch the server with the following command.
+
+```py
+python server.py
+```
+
+The server is accessed at http://localhost:8000. You can curl this model with the following command.
+```
+curl -X POST -H "Content-Type: application/json" --data '{"model": "something", "prompt": "a kitten in front of a fireplace"}' http://localhost:8000/v1/images/generations
+```
+
+If you need to upgrade some dependencies, you can use either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). For example, upgrade the dependencies with `uv` using the following command.
+
+```
+uv pip compile requirements.in -o requirements.txt
+```
+
+
+The server is built with [FastAPI](https://fastapi.tiangolo.com/async/). The endpoint for `v1/images/generations` is shown below.
+```py
+@app.post("/v1/images/generations")
+async def generate_image(image_input: TextToImageInput):
+ try:
+ loop = asyncio.get_event_loop()
+ scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
+ pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
+ generator = torch.Generator(device="cuda")
+ generator.manual_seed(random.randint(0, 10000000))
+ output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))
+ logger.info(f"output: {output}")
+ image_url = save_image(output.images[0])
+ return {"data": [{"url": image_url}]}
+ except Exception as e:
+ if isinstance(e, HTTPException):
+ raise e
+ elif hasattr(e, 'message'):
+ raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
+ raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
+```
+The `generate_image` function is defined as asynchronous with the [async](https://fastapi.tiangolo.com/async/) keyword so that FastAPI knows that whatever is happening in this function won't necessarily return a result right away. Once it hits some point in the function that it needs to await some other [Task](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task), the main thread goes back to answering other HTTP requests. This is shown in the code below with the [await](https://fastapi.tiangolo.com/async/#async-and-await) keyword.
+```py
+output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))
+```
+At this point, the execution of the pipeline function is placed onto a [new thread](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor), and the main thread performs other things until a result is returned from the `pipeline`.
+
+Another important aspect of this implementation is creating a `pipeline` from `shared_pipeline`. The goal behind this is to avoid loading the underlying model more than once onto the GPU while still allowing for each new request that is running on a separate thread to have its own generator and scheduler. The scheduler, in particular, is not thread-safe, and it will cause errors like: `IndexError: index 21 is out of bounds for dimension 0 with size 21` if you try to use the same scheduler across multiple threads.
diff --git a/examples/server/requirements.in b/examples/server/requirements.in
new file mode 100644
index 000000000000..b49b285a8fc8
--- /dev/null
+++ b/examples/server/requirements.in
@@ -0,0 +1,9 @@
+torch~=2.4.0
+transformers==4.46.1
+sentencepiece
+aiohttp
+py-consul
+prometheus_client >= 0.18.0
+prometheus-fastapi-instrumentator >= 7.0.0
+fastapi
+uvicorn
\ No newline at end of file
diff --git a/examples/server/requirements.txt b/examples/server/requirements.txt
new file mode 100644
index 000000000000..065a381f0c9b
--- /dev/null
+++ b/examples/server/requirements.txt
@@ -0,0 +1,124 @@
+# This file was autogenerated by uv via the following command:
+# uv pip compile requirements.in -o requirements.txt
+aiohappyeyeballs==2.4.3
+ # via aiohttp
+aiohttp==3.10.10
+ # via -r requirements.in
+aiosignal==1.3.1
+ # via aiohttp
+annotated-types==0.7.0
+ # via pydantic
+anyio==4.6.2.post1
+ # via starlette
+attrs==24.2.0
+ # via aiohttp
+certifi==2024.8.30
+ # via requests
+charset-normalizer==3.4.0
+ # via requests
+click==8.1.7
+ # via uvicorn
+fastapi==0.115.3
+ # via -r requirements.in
+filelock==3.16.1
+ # via
+ # huggingface-hub
+ # torch
+ # transformers
+frozenlist==1.5.0
+ # via
+ # aiohttp
+ # aiosignal
+fsspec==2024.10.0
+ # via
+ # huggingface-hub
+ # torch
+h11==0.14.0
+ # via uvicorn
+huggingface-hub==0.26.1
+ # via
+ # tokenizers
+ # transformers
+idna==3.10
+ # via
+ # anyio
+ # requests
+ # yarl
+jinja2==3.1.4
+ # via torch
+markupsafe==3.0.2
+ # via jinja2
+mpmath==1.3.0
+ # via sympy
+multidict==6.1.0
+ # via
+ # aiohttp
+ # yarl
+networkx==3.4.2
+ # via torch
+numpy==2.1.2
+ # via transformers
+packaging==24.1
+ # via
+ # huggingface-hub
+ # transformers
+prometheus-client==0.21.0
+ # via
+ # -r requirements.in
+ # prometheus-fastapi-instrumentator
+prometheus-fastapi-instrumentator==7.0.0
+ # via -r requirements.in
+propcache==0.2.0
+ # via yarl
+py-consul==1.5.3
+ # via -r requirements.in
+pydantic==2.9.2
+ # via fastapi
+pydantic-core==2.23.4
+ # via pydantic
+pyyaml==6.0.2
+ # via
+ # huggingface-hub
+ # transformers
+regex==2024.9.11
+ # via transformers
+requests==2.32.3
+ # via
+ # huggingface-hub
+ # py-consul
+ # transformers
+safetensors==0.4.5
+ # via transformers
+sentencepiece==0.2.0
+ # via -r requirements.in
+sniffio==1.3.1
+ # via anyio
+starlette==0.41.0
+ # via
+ # fastapi
+ # prometheus-fastapi-instrumentator
+sympy==1.13.3
+ # via torch
+tokenizers==0.20.1
+ # via transformers
+torch==2.4.1
+ # via -r requirements.in
+tqdm==4.66.5
+ # via
+ # huggingface-hub
+ # transformers
+transformers==4.46.1
+ # via -r requirements.in
+typing-extensions==4.12.2
+ # via
+ # fastapi
+ # huggingface-hub
+ # pydantic
+ # pydantic-core
+ # torch
+urllib3==2.2.3
+ # via requests
+uvicorn==0.32.0
+ # via -r requirements.in
+yarl==1.16.0
+ # via aiohttp
diff --git a/examples/server/server.py b/examples/server/server.py
new file mode 100644
index 000000000000..f8c9bd60d4bf
--- /dev/null
+++ b/examples/server/server.py
@@ -0,0 +1,133 @@
+import asyncio
+import logging
+import os
+import random
+import tempfile
+import traceback
+import uuid
+
+import aiohttp
+import torch
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+from pydantic import BaseModel
+
+from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline
+
+
+logger = logging.getLogger(__name__)
+
+
+class TextToImageInput(BaseModel):
+ model: str
+ prompt: str
+ size: str | None = None
+ n: int | None = None
+
+
+class HttpClient:
+ session: aiohttp.ClientSession = None
+
+ def start(self):
+ self.session = aiohttp.ClientSession()
+
+ async def stop(self):
+ await self.session.close()
+ self.session = None
+
+ def __call__(self) -> aiohttp.ClientSession:
+ assert self.session is not None
+ return self.session
+
+
+class TextToImagePipeline:
+ pipeline: StableDiffusion3Pipeline = None
+ device: str = None
+
+ def start(self):
+ if torch.cuda.is_available():
+ model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-large")
+ logger.info("Loading CUDA")
+ self.device = "cuda"
+ self.pipeline = StableDiffusion3Pipeline.from_pretrained(
+ model_path,
+ torch_dtype=torch.bfloat16,
+ ).to(device=self.device)
+ elif torch.backends.mps.is_available():
+ model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-medium")
+ logger.info("Loading MPS for Mac M Series")
+ self.device = "mps"
+ self.pipeline = StableDiffusion3Pipeline.from_pretrained(
+ model_path,
+ torch_dtype=torch.bfloat16,
+ ).to(device=self.device)
+ else:
+ raise Exception("No CUDA or MPS device available")
+
+
+app = FastAPI()
+service_url = os.getenv("SERVICE_URL", "http://localhost:8000")
+image_dir = os.path.join(tempfile.gettempdir(), "images")
+if not os.path.exists(image_dir):
+ os.makedirs(image_dir)
+app.mount("/images", StaticFiles(directory=image_dir), name="images")
+http_client = HttpClient()
+shared_pipeline = TextToImagePipeline()
+
+# Configure CORS settings
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"], # Allows all origins
+ allow_credentials=True,
+ allow_methods=["*"], # Allows all methods, e.g., GET, POST, OPTIONS, etc.
+ allow_headers=["*"], # Allows all headers
+)
+
+
+@app.on_event("startup")
+def startup():
+ http_client.start()
+ shared_pipeline.start()
+
+
+def save_image(image):
+ filename = "draw" + str(uuid.uuid4()).split("-")[0] + ".png"
+ image_path = os.path.join(image_dir, filename)
+ # write image to disk at image_path
+ logger.info(f"Saving image to {image_path}")
+ image.save(image_path)
+ return os.path.join(service_url, "images", filename)
+
+
+@app.get("/")
+@app.post("/")
+@app.options("/")
+async def base():
+ return "Welcome to Diffusers! Where you can use diffusion models to generate images"
+
+
+@app.post("/v1/images/generations")
+async def generate_image(image_input: TextToImageInput):
+ try:
+ loop = asyncio.get_event_loop()
+ scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
+ pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
+ generator = torch.Generator(device=shared_pipeline.device)
+ generator.manual_seed(random.randint(0, 10000000))
+ output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
+ logger.info(f"output: {output}")
+ image_url = save_image(output.images[0])
+ return {"data": [{"url": image_url}]}
+ except Exception as e:
+ if isinstance(e, HTTPException):
+ raise e
+ elif hasattr(e, "message"):
+ raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
+ raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
+
+
+if __name__ == "__main__":
+ import uvicorn
+
+ uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
index 02a064fa81ed..a34ecf17eb30 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -60,7 +60,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
@@ -141,9 +141,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
- formatted_images = []
-
- formatted_images.append(np.asarray(validation_image))
+ formatted_images = [np.asarray(validation_image)]
for image in images:
formatted_images.append(np.asarray(image))
diff --git a/examples/text_to_image/test_text_to_image.py b/examples/text_to_image/test_text_to_image.py
index 6231a89b1d1d..7a599aeb351d 100644
--- a/examples/text_to_image/test_text_to_image.py
+++ b/examples/text_to_image/test_text_to_image.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/text_to_image/test_text_to_image_lora.py b/examples/text_to_image/test_text_to_image_lora.py
index 4604b9f5210c..2406515c36d2 100644
--- a/examples/text_to_image/test_text_to_image_lora.py
+++ b/examples/text_to_image/test_text_to_image_lora.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 684bf352a6c1..adfb7b74477f 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -57,7 +57,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -365,8 +365,8 @@ def parse_args():
"--dream_training",
action="store_true",
help=(
- "Use the DREAM training method, which makes training more efficient and accurate at the ",
- "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210",
+ "Use the DREAM training method, which makes training more efficient and accurate at the "
+ "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
),
)
parser.add_argument(
diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py
index 4a80067d693d..4564c1d16f45 100644
--- a/examples/text_to_image/train_text_to_image_flax.py
+++ b/examples/text_to_image/train_text_to_image_flax.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -49,7 +49,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 379519b4c812..82c395c685f8 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -56,7 +56,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -515,10 +515,6 @@ def main():
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
- # Freeze the unet parameters before adding adapters
- for param in unet.parameters():
- param.requires_grad_(False)
-
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index fe098c8638d5..2061f0c6775b 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -68,7 +68,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -137,7 +137,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
pipeline_args = {"prompt": args.validation_prompt}
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index bcf0fa9eb0ac..29da1f2efbaa 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -55,7 +55,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -483,7 +483,6 @@ def parse_args(input_args=None):
# Sanity checks
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Need either a dataset name or a training folder.")
-
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
@@ -824,9 +823,7 @@ def load_model_hook(models, input_dir):
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
- args.dataset_name,
- args.dataset_config_name,
- cache_dir=args.cache_dir,
+ args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
)
else:
data_files = {}
@@ -922,7 +919,7 @@ def preprocess_train(examples):
# fingerprint used by the cache for the other processes to load the result
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
new_fingerprint = Hasher.hash(args)
- new_fingerprint_for_vae = Hasher.hash(vae_path)
+ new_fingerprint_for_vae = Hasher.hash((vae_path, args))
train_dataset_with_embeddings = train_dataset.map(
compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
)
@@ -1244,7 +1241,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
pipeline.set_progress_bar_config(disable=True)
# run inference
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = (
+ torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ if args.seed is not None
+ else None
+ )
pipeline_args = {"prompt": args.validation_prompt}
with autocast_ctx:
@@ -1308,7 +1309,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline = pipeline.to(accelerator.device)
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ generator = (
+ torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+ )
with autocast_ctx:
images = [
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 6b710531836b..757a12045f10 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -81,7 +81,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index ee7b1580d145..3ee675e76bbb 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -56,7 +56,7 @@
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index a4629f0f43d6..11463943c448 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -76,7 +76,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 1f5e1de240cb..45b674cb5894 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -29,7 +29,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py
index 664a7f7365b0..aa5d4c67b642 100644
--- a/examples/vqgan/test_vqgan.py
+++ b/examples/vqgan/test_vqgan.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py
index d16dce921896..992722fa7a78 100644
--- a/examples/vqgan/train_vqgan.py
+++ b/examples/vqgan/train_vqgan.py
@@ -50,7 +50,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.33.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/scripts/change_naming_configs_and_checkpoints.py b/scripts/change_naming_configs_and_checkpoints.py
index adc1605e95b3..4220901c13bf 100644
--- a/scripts/change_naming_configs_and_checkpoints.py
+++ b/scripts/change_naming_configs_and_checkpoints.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_blipdiffusion_to_diffusers.py b/scripts/convert_blipdiffusion_to_diffusers.py
index 03cf67e5476b..2c286ea0fdc7 100644
--- a/scripts/convert_blipdiffusion_to_diffusers.py
+++ b/scripts/convert_blipdiffusion_to_diffusers.py
@@ -303,10 +303,11 @@ def save_blip_diffusion_model(model, args):
qformer = get_qformer(model)
qformer.eval()
- text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
- vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
-
- unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ text_encoder = ContextCLIPTextModel.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="text_encoder"
+ )
+ vae = AutoencoderKL.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
vae.eval()
text_encoder.eval()
scheduler = PNDMScheduler(
@@ -316,7 +317,7 @@ def save_blip_diffusion_model(model, args):
set_alpha_to_one=False,
skip_prk_steps=True,
)
- tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
+ tokenizer = CLIPTokenizer.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="tokenizer")
image_processor = BlipImageProcessor()
blip_diffusion = BlipDiffusionPipeline(
tokenizer=tokenizer,
diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py
index 4343eaf34038..7eeed240c4de 100644
--- a/scripts/convert_cogvideox_to_diffusers.py
+++ b/scripts/convert_cogvideox_to_diffusers.py
@@ -80,6 +80,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
"post_attn1_layernorm": "norm2.norm",
"time_embed.0": "time_embedding.linear_1",
"time_embed.2": "time_embedding.linear_2",
+ "ofs_embed.0": "ofs_embedding.linear_1",
+ "ofs_embed.2": "ofs_embedding.linear_2",
"mixins.patch_embed": "patch_embed",
"mixins.final_layer.norm_final": "norm_out.norm",
"mixins.final_layer.linear": "proj_out",
@@ -140,6 +142,7 @@ def convert_transformer(
use_rotary_positional_embeddings: bool,
i2v: bool,
dtype: torch.dtype,
+ init_kwargs: Dict[str, Any],
):
PREFIX_KEY = "model.diffusion_model."
@@ -149,7 +152,9 @@ def convert_transformer(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
- use_learned_positional_embeddings=i2v,
+ ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V
+ use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V
+ **init_kwargs,
).to(dtype=dtype)
for key in list(original_state_dict.keys()):
@@ -163,13 +168,18 @@ def convert_transformer(
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
+
transformer.load_state_dict(original_state_dict, strict=True)
return transformer
-def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
+ init_kwargs = {"scaling_factor": scaling_factor}
+ if version == "1.5":
+ init_kwargs.update({"invert_scale_latents": True})
+
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
- vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
+ vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -187,6 +197,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
return vae
+def get_transformer_init_kwargs(version: str):
+ if version == "1.0":
+ vae_scale_factor_spatial = 8
+ init_kwargs = {
+ "patch_size": 2,
+ "patch_size_t": None,
+ "patch_bias": True,
+ "sample_height": 480 // vae_scale_factor_spatial,
+ "sample_width": 720 // vae_scale_factor_spatial,
+ "sample_frames": 49,
+ }
+
+ elif version == "1.5":
+ vae_scale_factor_spatial = 8
+ init_kwargs = {
+ "patch_size": 2,
+ "patch_size_t": 2,
+ "patch_bias": False,
+ "sample_height": 300,
+ "sample_width": 300,
+ "sample_frames": 81,
+ }
+ else:
+ raise ValueError("Unsupported version of CogVideoX.")
+
+ return init_kwargs
+
+
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -202,6 +240,12 @@ def get_args():
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
+ parser.add_argument(
+ "--typecast_text_encoder",
+ action="store_true",
+ default=False,
+ help="Whether or not to apply fp16/bf16 precision to text_encoder",
+ )
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
@@ -214,7 +258,18 @@ def get_args():
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
- parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
+ parser.add_argument(
+ "--i2v",
+ action="store_true",
+ default=False,
+ help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
+ )
+ parser.add_argument(
+ "--version",
+ choices=["1.0", "1.5"],
+ default="1.0",
+ help="Which version of CogVideoX to use for initializing default modeling parameters.",
+ )
return parser.parse_args()
@@ -230,6 +285,7 @@ def get_args():
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
if args.transformer_ckpt_path is not None:
+ init_kwargs = get_transformer_init_kwargs(args.version)
transformer = convert_transformer(
args.transformer_ckpt_path,
args.num_layers,
@@ -237,14 +293,19 @@ def get_args():
args.use_rotary_positional_embeddings,
args.i2v,
dtype,
+ init_kwargs,
)
if args.vae_ckpt_path is not None:
- vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+ # Keep VAE in float32 for better quality
+ vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+ if args.typecast_text_encoder:
+ text_encoder = text_encoder.to(dtype=dtype)
+
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
@@ -276,11 +337,6 @@ def get_args():
scheduler=scheduler,
)
- if args.fp16:
- pipe = pipe.to(dtype=torch.float16)
- if args.bf16:
- pipe = pipe.to(dtype=torch.bfloat16)
-
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here).
diff --git a/scripts/convert_cogview3_to_diffusers.py b/scripts/convert_cogview3_to_diffusers.py
index 48cda2084240..605555ebdbef 100644
--- a/scripts/convert_cogview3_to_diffusers.py
+++ b/scripts/convert_cogview3_to_diffusers.py
@@ -36,7 +36,7 @@
from diffusers.utils.import_utils import is_accelerate_available
-CTX = init_empty_weights if is_accelerate_available else nullcontext
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
TOKENIZER_MAX_LENGTH = 224
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
new file mode 100644
index 000000000000..b6d01c797aeb
--- /dev/null
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -0,0 +1,254 @@
+"""
+Convert a CogView4 checkpoint from SAT(https://github.com/THUDM/SwissArmyTransformer) to the Diffusers format.
+(deprecated Since 2025-02-07 and will remove it in later CogView4 version)
+
+This script converts a CogView4 checkpoint to the Diffusers format, which can then be used
+with the Diffusers library.
+
+Example usage:
+ python scripts/convert_cogview4_to_diffusers.py \
+ --transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \
+ --vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
+ --output_path "THUDM/CogView4-6B" \
+ --dtype "bf16"
+
+Arguments:
+ --transformer_checkpoint_path: Path to Transformer state dict.
+ --vae_checkpoint_path: Path to VAE state dict.
+ --output_path: The path to save the converted model.
+ --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
+ --text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
+ --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
+
+ Default is "bf16" because CogView4 uses bfloat16 for Training.
+
+Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
+"""
+
+import argparse
+from contextlib import nullcontext
+
+import torch
+from accelerate import init_empty_weights
+from transformers import GlmForCausalLM, PreTrainedTokenizerFast
+
+from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
+from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
+parser.add_argument("--vae_checkpoint_path", default=None, type=str)
+parser.add_argument("--output_path", required=True, type=str)
+parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
+parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
+parser.add_argument("--dtype", type=str, default="bf16")
+
+args = parser.parse_args()
+
+
+# this is specific to `AdaLayerNormContinuous`:
+# diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
+def swap_scale_shift(weight, dim):
+ """
+ Swap the scale and shift components in the weight tensor.
+
+ Args:
+ weight (torch.Tensor): The original weight tensor.
+ dim (int): The dimension along which to split.
+
+ Returns:
+ torch.Tensor: The modified weight tensor with scale and shift swapped.
+ """
+ shift, scale = weight.chunk(2, dim=dim)
+ new_weight = torch.cat([scale, shift], dim=dim)
+ return new_weight
+
+
+def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")
+ original_state_dict = original_state_dict["module"]
+ original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
+
+ new_state_dict = {}
+
+ # Convert patch_embed
+ new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
+ new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
+ new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
+ new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
+
+ # Convert time_condition_embed
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+ "time_embed.0.weight"
+ )
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+ "time_embed.0.bias"
+ )
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+ "time_embed.2.weight"
+ )
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+ "time_embed.2.bias"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
+ "label_emb.0.0.weight"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
+ "label_emb.0.0.bias"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
+ "label_emb.0.2.weight"
+ )
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
+ "label_emb.0.2.bias"
+ )
+
+ # Convert transformer blocks, for cogview4 is 28 blocks
+ for i in range(28):
+ block_prefix = f"transformer_blocks.{i}."
+ old_prefix = f"transformer.layers.{i}."
+ adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
+ new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
+ new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
+
+ qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
+ qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+ q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
+
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+ new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+ new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+ new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
+
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
+ old_prefix + "attention.dense.weight"
+ )
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
+ old_prefix + "attention.dense.bias"
+ )
+
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
+ old_prefix + "mlp.dense_h_to_4h.weight"
+ )
+ new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
+ old_prefix + "mlp.dense_h_to_4h.bias"
+ )
+ new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
+ old_prefix + "mlp.dense_4h_to_h.weight"
+ )
+ new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
+
+ # Convert final norm and projection
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
+ )
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+ original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
+ )
+ new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
+ new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
+
+ return new_state_dict
+
+
+def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+ return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
+
+
+def main(args):
+ if args.dtype == "fp16":
+ dtype = torch.float16
+ elif args.dtype == "bf16":
+ dtype = torch.bfloat16
+ elif args.dtype == "fp32":
+ dtype = torch.float32
+ else:
+ raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+ transformer = None
+ vae = None
+
+ if args.transformer_checkpoint_path is not None:
+ converted_transformer_state_dict = convert_cogview4_transformer_checkpoint_to_diffusers(
+ args.transformer_checkpoint_path
+ )
+ transformer = CogView4Transformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ num_layers=28,
+ attention_head_dim=128,
+ num_attention_heads=32,
+ out_channels=16,
+ text_embed_dim=4096,
+ time_embed_dim=512,
+ condition_dim=256,
+ pos_embed_max_size=128,
+ )
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+ if dtype is not None:
+ # Original checkpoint data type will be preserved
+ transformer = transformer.to(dtype=dtype)
+
+ if args.vae_checkpoint_path is not None:
+ vae_config = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ("DownEncoderBlock2D",) * 4,
+ "up_block_types": ("UpDecoderBlock2D",) * 4,
+ "block_out_channels": (128, 512, 1024, 1024),
+ "layers_per_block": 3,
+ "act_fn": "silu",
+ "latent_channels": 16,
+ "norm_num_groups": 32,
+ "sample_size": 1024,
+ "scaling_factor": 1.0,
+ "shift_factor": 0.0,
+ "force_upcast": True,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ "mid_block_add_attention": False,
+ }
+ converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
+ if dtype is not None:
+ vae = vae.to(dtype=dtype)
+
+ text_encoder_id = "THUDM/glm-4-9b-hf"
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
+ text_encoder = GlmForCausalLM.from_pretrained(
+ text_encoder_id,
+ cache_dir=args.text_encoder_cache_dir,
+ torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
+ )
+
+ for param in text_encoder.parameters():
+ param.data = param.data.contiguous()
+
+ scheduler = FlowMatchEulerDiscreteScheduler(
+ base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
+ )
+
+ pipe = CogView4Pipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ # This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can
+ # save some memory used for model loading.
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py
new file mode 100644
index 000000000000..8faeccb13888
--- /dev/null
+++ b/scripts/convert_cogview4_to_diffusers_megatron.py
@@ -0,0 +1,384 @@
+"""
+Convert a CogView4 checkpoint from Megatron to the Diffusers format.
+
+Example usage:
+ python scripts/convert_cogview4_to_diffusers.py \
+ --transformer_checkpoint_path 'your path/cogview4_6b/mp_rank_00/model_optim_rng.pt' \
+ --vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
+ --output_path "THUDM/CogView4-6B" \
+ --dtype "bf16"
+
+Arguments:
+ --transformer_checkpoint_path: Path to Transformer state dict.
+ --vae_checkpoint_path: Path to VAE state dict.
+ --output_path: The path to save the converted model.
+ --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
+ --text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used.
+ --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
+
+ Default is "bf16" because CogView4 uses bfloat16 for training.
+
+Note: You must provide either --transformer_checkpoint_path or --vae_checkpoint_path.
+"""
+
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import GlmModel, PreTrainedTokenizerFast
+
+from diffusers import (
+ AutoencoderKL,
+ CogView4ControlPipeline,
+ CogView4Pipeline,
+ CogView4Transformer2DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--transformer_checkpoint_path",
+ default=None,
+ type=str,
+ help="Path to Megatron (not SAT) Transformer checkpoint, e.g., 'model_optim_rng.pt'.",
+)
+parser.add_argument(
+ "--vae_checkpoint_path",
+ default=None,
+ type=str,
+ help="(Optional) Path to VAE checkpoint, e.g., 'imagekl_ch16.pt'.",
+)
+parser.add_argument(
+ "--output_path",
+ required=True,
+ type=str,
+ help="Directory to save the final Diffusers format pipeline.",
+)
+parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ default=False,
+ help="Whether to push the converted model to the HuggingFace Hub.",
+)
+parser.add_argument(
+ "--text_encoder_cache_dir",
+ type=str,
+ default=None,
+ help="Specify the cache directory for the text encoder.",
+)
+parser.add_argument(
+ "--dtype",
+ type=str,
+ default="bf16",
+ choices=["fp16", "bf16", "fp32"],
+ help="Data type to save the model in.",
+)
+
+parser.add_argument(
+ "--num_layers",
+ type=int,
+ default=28,
+ help="Number of Transformer layers (e.g., 28, 48...).",
+)
+parser.add_argument(
+ "--num_heads",
+ type=int,
+ default=32,
+ help="Number of attention heads.",
+)
+parser.add_argument(
+ "--hidden_size",
+ type=int,
+ default=4096,
+ help="Transformer hidden dimension size.",
+)
+parser.add_argument(
+ "--attention_head_dim",
+ type=int,
+ default=128,
+ help="Dimension of each attention head.",
+)
+parser.add_argument(
+ "--time_embed_dim",
+ type=int,
+ default=512,
+ help="Dimension of time embeddings.",
+)
+parser.add_argument(
+ "--condition_dim",
+ type=int,
+ default=256,
+ help="Dimension of condition embeddings.",
+)
+parser.add_argument(
+ "--pos_embed_max_size",
+ type=int,
+ default=128,
+ help="Maximum size for positional embeddings.",
+)
+parser.add_argument(
+ "--control",
+ action="store_true",
+ default=False,
+ help="Whether to use control model.",
+)
+
+args = parser.parse_args()
+
+
+def swap_scale_shift(weight, dim):
+ """
+ Swap the scale and shift components in the weight tensor.
+
+ Args:
+ weight (torch.Tensor): The original weight tensor.
+ dim (int): The dimension along which to split.
+
+ Returns:
+ torch.Tensor: The modified weight tensor with scale and shift swapped.
+ """
+ shift, scale = weight.chunk(2, dim=dim)
+ new_weight = torch.cat([scale, shift], dim=dim)
+ return new_weight
+
+
+def convert_megatron_transformer_checkpoint_to_diffusers(
+ ckpt_path: str,
+ num_layers: int,
+ num_heads: int,
+ hidden_size: int,
+):
+ """
+ Convert a Megatron Transformer checkpoint to Diffusers format.
+
+ Args:
+ ckpt_path (str): Path to the Megatron Transformer checkpoint.
+ num_layers (int): Number of Transformer layers.
+ num_heads (int): Number of attention heads.
+ hidden_size (int): Hidden size of the Transformer.
+
+ Returns:
+ dict: The converted state dictionary compatible with Diffusers.
+ """
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+ mega = ckpt["model"]
+
+ new_state_dict = {}
+
+ # Patch Embedding
+ new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
+ hidden_size, 128 if args.control else 64
+ )
+ new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
+ new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
+ new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
+
+ # Time Condition Embedding
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = mega[
+ "time_embedding.time_embed.0.weight"
+ ]
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = mega["time_embedding.time_embed.0.bias"]
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = mega[
+ "time_embedding.time_embed.2.weight"
+ ]
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = mega["time_embedding.time_embed.2.bias"]
+
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = mega[
+ "label_embedding.label_embed.0.weight"
+ ]
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = mega[
+ "label_embedding.label_embed.0.bias"
+ ]
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = mega[
+ "label_embedding.label_embed.2.weight"
+ ]
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = mega[
+ "label_embedding.label_embed.2.bias"
+ ]
+
+ # Convert each Transformer layer
+ for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"):
+ block_prefix = f"transformer_blocks.{i}."
+
+ # AdaLayerNorm
+ new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"]
+ new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"]
+ qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
+ qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
+
+ # Reshape to match SAT logic
+ qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size)
+ qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size)
+
+ qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads)
+ qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size)
+
+ # Assign to Diffusers keys
+ q, k, v = torch.chunk(qkv_weight, 3, dim=0)
+ qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0)
+
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+ new_state_dict[block_prefix + "attn1.to_q.bias"] = qb
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+ new_state_dict[block_prefix + "attn1.to_k.bias"] = kb
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+ new_state_dict[block_prefix + "attn1.to_v.bias"] = vb
+
+ # Attention Output
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
+ f"decoder.layers.{i}.self_attention.linear_proj.weight"
+ ]
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
+ f"decoder.layers.{i}.self_attention.linear_proj.bias"
+ ]
+
+ # MLP
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.weight"]
+ new_state_dict[block_prefix + "ff.net.0.proj.bias"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.bias"]
+ new_state_dict[block_prefix + "ff.net.2.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc2.weight"]
+ new_state_dict[block_prefix + "ff.net.2.bias"] = mega[f"decoder.layers.{i}.mlp.linear_fc2.bias"]
+
+ # Final Layers
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(mega["adaln_final.weight"], dim=0)
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(mega["adaln_final.bias"], dim=0)
+ new_state_dict["proj_out.weight"] = mega["output_projector.weight"]
+ new_state_dict["proj_out.bias"] = mega["output_projector.bias"]
+
+ return new_state_dict
+
+
+def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
+ """
+ Convert a CogView4 VAE checkpoint to Diffusers format.
+
+ Args:
+ ckpt_path (str): Path to the VAE checkpoint.
+ vae_config (dict): Configuration dictionary for the VAE.
+
+ Returns:
+ dict: The converted VAE state dictionary compatible with Diffusers.
+ """
+ original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
+ return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
+
+
+def main(args):
+ """
+ Main function to convert CogView4 checkpoints to Diffusers format.
+
+ Args:
+ args (argparse.Namespace): Parsed command-line arguments.
+ """
+ # Determine the desired data type
+ if args.dtype == "fp16":
+ dtype = torch.float16
+ elif args.dtype == "bf16":
+ dtype = torch.bfloat16
+ elif args.dtype == "fp32":
+ dtype = torch.float32
+ else:
+ raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+ transformer = None
+ vae = None
+
+ # Convert Transformer checkpoint if provided
+ if args.transformer_checkpoint_path is not None:
+ converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers(
+ ckpt_path=args.transformer_checkpoint_path,
+ num_layers=args.num_layers,
+ num_heads=args.num_heads,
+ hidden_size=args.hidden_size,
+ )
+ transformer = CogView4Transformer2DModel(
+ patch_size=2,
+ in_channels=32 if args.control else 16,
+ num_layers=args.num_layers,
+ attention_head_dim=args.attention_head_dim,
+ num_attention_heads=args.num_heads,
+ out_channels=16,
+ text_embed_dim=args.hidden_size,
+ time_embed_dim=args.time_embed_dim,
+ condition_dim=args.condition_dim,
+ pos_embed_max_size=args.pos_embed_max_size,
+ )
+
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+
+ # Convert to the specified dtype
+ if dtype is not None:
+ transformer = transformer.to(dtype=dtype)
+
+ # Convert VAE checkpoint if provided
+ if args.vae_checkpoint_path is not None:
+ vae_config = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ("DownEncoderBlock2D",) * 4,
+ "up_block_types": ("UpDecoderBlock2D",) * 4,
+ "block_out_channels": (128, 512, 1024, 1024),
+ "layers_per_block": 3,
+ "act_fn": "silu",
+ "latent_channels": 16,
+ "norm_num_groups": 32,
+ "sample_size": 1024,
+ "scaling_factor": 1.0,
+ "shift_factor": 0.0,
+ "force_upcast": True,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ "mid_block_add_attention": False,
+ }
+ converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
+ if dtype is not None:
+ vae = vae.to(dtype=dtype)
+
+ # Load the text encoder and tokenizer
+ text_encoder_id = "THUDM/glm-4-9b-hf"
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
+ text_encoder = GlmModel.from_pretrained(
+ text_encoder_id,
+ cache_dir=args.text_encoder_cache_dir,
+ torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
+ )
+ for param in text_encoder.parameters():
+ param.data = param.data.contiguous()
+
+ # Initialize the scheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(
+ base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
+ )
+
+ # Create the pipeline
+ if args.control:
+ pipe = CogView4ControlPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ else:
+ pipe = CogView4Pipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ # Save the converted pipeline
+ pipe.save_pretrained(
+ args.output_path,
+ safe_serialization=True,
+ max_shard_size="5GB",
+ push_to_hub=args.push_to_hub,
+ )
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/scripts/convert_consistency_decoder.py b/scripts/convert_consistency_decoder.py
index 0cb5fc50dd60..629c784c095a 100644
--- a/scripts/convert_consistency_decoder.py
+++ b/scripts/convert_consistency_decoder.py
@@ -73,7 +73,7 @@ def _download(url: str, root: str):
loop.update(len(buffer))
if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
- raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
return download_target
diff --git a/scripts/convert_dcae_to_diffusers.py b/scripts/convert_dcae_to_diffusers.py
new file mode 100644
index 000000000000..15f79a8154e6
--- /dev/null
+++ b/scripts/convert_dcae_to_diffusers.py
@@ -0,0 +1,323 @@
+import argparse
+from typing import Any, Dict
+
+import torch
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
+from diffusers import AutoencoderDC
+
+
+def remap_qkv_(key: str, state_dict: Dict[str, Any]):
+ qkv = state_dict.pop(key)
+ q, k, v = torch.chunk(qkv, 3, dim=0)
+ parent_module, _, _ = key.rpartition(".qkv.conv.weight")
+ state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
+ state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
+ state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
+
+
+def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
+ parent_module, _, _ = key.rpartition(".proj.conv.weight")
+ state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
+
+
+AE_KEYS_RENAME_DICT = {
+ # common
+ "main.": "",
+ "op_list.": "",
+ "context_module": "attn",
+ "local_module": "conv_out",
+ # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
+ # If there were more scales, there would be more layers, so a loop would be better to handle this
+ "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
+ "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
+ "depth_conv.conv": "conv_depth",
+ "inverted_conv.conv": "conv_inverted",
+ "point_conv.conv": "conv_point",
+ "point_conv.norm": "norm",
+ "conv.conv.": "conv.",
+ "conv1.conv": "conv1",
+ "conv2.conv": "conv2",
+ "conv2.norm": "norm",
+ "proj.norm": "norm_out",
+ # encoder
+ "encoder.project_in.conv": "encoder.conv_in",
+ "encoder.project_out.0.conv": "encoder.conv_out",
+ "encoder.stages": "encoder.down_blocks",
+ # decoder
+ "decoder.project_in.conv": "decoder.conv_in",
+ "decoder.project_out.0": "decoder.norm_out",
+ "decoder.project_out.2.conv": "decoder.conv_out",
+ "decoder.stages": "decoder.up_blocks",
+}
+
+AE_F32C32_KEYS = {
+ # encoder
+ "encoder.project_in.conv": "encoder.conv_in.conv",
+ # decoder
+ "decoder.project_out.2.conv": "decoder.conv_out.conv",
+}
+
+AE_F64C128_KEYS = {
+ # encoder
+ "encoder.project_in.conv": "encoder.conv_in.conv",
+ # decoder
+ "decoder.project_out.2.conv": "decoder.conv_out.conv",
+}
+
+AE_F128C512_KEYS = {
+ # encoder
+ "encoder.project_in.conv": "encoder.conv_in.conv",
+ # decoder
+ "decoder.project_out.2.conv": "decoder.conv_out.conv",
+}
+
+AE_SPECIAL_KEYS_REMAP = {
+ "qkv.conv.weight": remap_qkv_,
+ "proj.conv.weight": remap_proj_conv_,
+}
+
+
+def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
+ state_dict = saved_dict
+ if "model" in saved_dict.keys():
+ state_dict = state_dict["model"]
+ if "module" in saved_dict.keys():
+ state_dict = state_dict["module"]
+ if "state_dict" in saved_dict.keys():
+ state_dict = state_dict["state_dict"]
+ return state_dict
+
+
+def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def convert_ae(config_name: str, dtype: torch.dtype):
+ config = get_ae_config(config_name)
+ hub_id = f"mit-han-lab/{config_name}"
+ ckpt_path = hf_hub_download(hub_id, "model.safetensors")
+ original_state_dict = get_state_dict(load_file(ckpt_path))
+
+ ae = AutoencoderDC(**config).to(dtype=dtype)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ ae.load_state_dict(original_state_dict, strict=True)
+ return ae
+
+
+def get_ae_config(name: str):
+ if name in ["dc-ae-f32c32-sana-1.0"]:
+ config = {
+ "latent_channels": 32,
+ "encoder_block_types": (
+ "ResBlock",
+ "ResBlock",
+ "ResBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ ),
+ "decoder_block_types": (
+ "ResBlock",
+ "ResBlock",
+ "ResBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ ),
+ "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
+ "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
+ "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
+ "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
+ "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
+ "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
+ "downsample_block_type": "conv",
+ "upsample_block_type": "interpolate",
+ "decoder_norm_types": "rms_norm",
+ "decoder_act_fns": "silu",
+ "scaling_factor": 0.41407,
+ }
+ elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
+ AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
+ config = {
+ "latent_channels": 32,
+ "encoder_block_types": [
+ "ResBlock",
+ "ResBlock",
+ "ResBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ ],
+ "decoder_block_types": [
+ "ResBlock",
+ "ResBlock",
+ "ResBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ ],
+ "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
+ "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
+ "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
+ "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
+ "encoder_qkv_multiscales": ((), (), (), (), (), ()),
+ "decoder_qkv_multiscales": ((), (), (), (), (), ()),
+ "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
+ "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
+ }
+ if name == "dc-ae-f32c32-in-1.0":
+ config["scaling_factor"] = 0.3189
+ elif name == "dc-ae-f32c32-mix-1.0":
+ config["scaling_factor"] = 0.4552
+ elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
+ AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
+ config = {
+ "latent_channels": 128,
+ "encoder_block_types": [
+ "ResBlock",
+ "ResBlock",
+ "ResBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ ],
+ "decoder_block_types": [
+ "ResBlock",
+ "ResBlock",
+ "ResBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ ],
+ "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
+ "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
+ "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
+ "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
+ "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
+ "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
+ "decoder_norm_types": [
+ "batch_norm",
+ "batch_norm",
+ "batch_norm",
+ "rms_norm",
+ "rms_norm",
+ "rms_norm",
+ "rms_norm",
+ ],
+ "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
+ }
+ if name == "dc-ae-f64c128-in-1.0":
+ config["scaling_factor"] = 0.2889
+ elif name == "dc-ae-f64c128-mix-1.0":
+ config["scaling_factor"] = 0.4538
+ elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
+ AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
+ config = {
+ "latent_channels": 512,
+ "encoder_block_types": [
+ "ResBlock",
+ "ResBlock",
+ "ResBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ ],
+ "decoder_block_types": [
+ "ResBlock",
+ "ResBlock",
+ "ResBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ "EfficientViTBlock",
+ ],
+ "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
+ "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
+ "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
+ "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
+ "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
+ "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
+ "decoder_norm_types": [
+ "batch_norm",
+ "batch_norm",
+ "batch_norm",
+ "rms_norm",
+ "rms_norm",
+ "rms_norm",
+ "rms_norm",
+ "rms_norm",
+ ],
+ "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
+ }
+ if name == "dc-ae-f128c512-in-1.0":
+ config["scaling_factor"] = 0.4883
+ elif name == "dc-ae-f128c512-mix-1.0":
+ config["scaling_factor"] = 0.3620
+ else:
+ raise ValueError("Invalid config name provided.")
+
+ return config
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--config_name",
+ type=str,
+ default="dc-ae-f32c32-sana-1.0",
+ choices=[
+ "dc-ae-f32c32-sana-1.0",
+ "dc-ae-f32c32-in-1.0",
+ "dc-ae-f32c32-mix-1.0",
+ "dc-ae-f64c128-in-1.0",
+ "dc-ae-f64c128-mix-1.0",
+ "dc-ae-f128c512-in-1.0",
+ "dc-ae-f128c512-mix-1.0",
+ ],
+ help="The DCAE checkpoint to convert",
+ )
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
+ parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
+ return parser.parse_args()
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+VARIANT_MAPPING = {
+ "fp32": None,
+ "fp16": "fp16",
+ "bf16": "bf16",
+}
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ dtype = DTYPE_MAPPING[args.dtype]
+ variant = VARIANT_MAPPING[args.dtype]
+
+ ae = convert_ae(args.config_name, dtype)
+ ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
diff --git a/scripts/convert_flux_to_diffusers.py b/scripts/convert_flux_to_diffusers.py
index 05a1da256d33..fccac70dd855 100644
--- a/scripts/convert_flux_to_diffusers.py
+++ b/scripts/convert_flux_to_diffusers.py
@@ -31,12 +31,14 @@
--vae
"""
-CTX = init_empty_weights if is_accelerate_available else nullcontext
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--in_channels", type=int, default=64)
+parser.add_argument("--out_channels", type=int, default=None)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--transformer", action="store_true")
parser.add_argument("--output_path", type=str)
@@ -279,10 +281,13 @@ def main(args):
num_single_layers = 38
inner_dim = 3072
mlp_ratio = 4.0
+
converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
)
- transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
+ transformer = FluxTransformer2DModel(
+ in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance
+ )
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
print(
diff --git a/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py b/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py
new file mode 100644
index 000000000000..b701b7fb40b1
--- /dev/null
+++ b/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py
@@ -0,0 +1,97 @@
+import argparse
+from contextlib import nullcontext
+
+import safetensors.torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+
+from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
+
+
+if is_transformers_available():
+ from transformers import CLIPVisionModelWithProjection
+
+ vision = True
+else:
+ vision = False
+
+"""
+python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
+--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
+--filename "flux-ip-adapter.safetensors"
+--output_path "flux-ip-adapter-hf/"
+"""
+
+
+CTX = init_empty_weights if is_accelerate_available else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
+parser.add_argument("--filename", default="flux.safetensors", type=str)
+parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--output_path", type=str)
+parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
+
+args = parser.parse_args()
+
+
+def load_original_checkpoint(args):
+ if args.original_state_dict_repo_id is not None:
+ ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
+ elif args.checkpoint_path is not None:
+ ckpt_path = args.checkpoint_path
+ else:
+ raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+
+ original_state_dict = safetensors.torch.load_file(ckpt_path)
+ return original_state_dict
+
+
+def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
+ converted_state_dict = {}
+
+ # image_proj
+ ## norm
+ converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
+ converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
+ ## proj
+ converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
+ converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"ip_adapter.{i}."
+ # to_k_ip
+ converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
+ )
+ converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
+ )
+ # to_v_ip
+ converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
+ )
+ converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
+ )
+
+ return converted_state_dict
+
+
+def main(args):
+ original_ckpt = load_original_checkpoint(args)
+
+ num_layers = 19
+ converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
+
+ print("Saving Flux IP-Adapter in Diffusers format.")
+ safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
+
+ if vision:
+ model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
+ model.save_pretrained(f"{args.output_path}/image_encoder")
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py
new file mode 100644
index 000000000000..c84809d7f68a
--- /dev/null
+++ b/scripts/convert_hunyuan_video_to_diffusers.py
@@ -0,0 +1,353 @@
+import argparse
+from typing import Any, Dict
+
+import torch
+from accelerate import init_empty_weights
+from transformers import (
+ AutoModel,
+ AutoTokenizer,
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ LlavaForConditionalGeneration,
+)
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanVideoImageToVideoPipeline,
+ HunyuanVideoPipeline,
+ HunyuanVideoTransformer3DModel,
+)
+
+
+def remap_norm_scale_shift_(key, state_dict):
+ weight = state_dict.pop(key)
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+
+def remap_txt_in_(key, state_dict):
+ def rename_key(key):
+ new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+ new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+ new_key = new_key.replace("txt_in", "context_embedder")
+ new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+ new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+ new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+ new_key = new_key.replace("mlp", "ff")
+ return new_key
+
+ if "self_attn_qkv" in key:
+ weight = state_dict.pop(key)
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+ else:
+ state_dict[rename_key(key)] = state_dict.pop(key)
+
+
+def remap_img_attn_qkv_(key, state_dict):
+ weight = state_dict.pop(key)
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+
+def remap_txt_attn_qkv_(key, state_dict):
+ weight = state_dict.pop(key)
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+
+def remap_single_transformer_blocks_(key, state_dict):
+ hidden_size = 3072
+
+ if "linear1.weight" in key:
+ linear1_weight = state_dict.pop(key)
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+ q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
+ state_dict[f"{new_key}.attn.to_q.weight"] = q
+ state_dict[f"{new_key}.attn.to_k.weight"] = k
+ state_dict[f"{new_key}.attn.to_v.weight"] = v
+ state_dict[f"{new_key}.proj_mlp.weight"] = mlp
+
+ elif "linear1.bias" in key:
+ linear1_bias = state_dict.pop(key)
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
+ state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
+ state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
+ state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
+ state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
+
+ else:
+ new_key = key.replace("single_blocks", "single_transformer_blocks")
+ new_key = new_key.replace("linear2", "proj_out")
+ new_key = new_key.replace("q_norm", "attn.norm_q")
+ new_key = new_key.replace("k_norm", "attn.norm_k")
+ state_dict[new_key] = state_dict.pop(key)
+
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+ "img_in": "x_embedder",
+ "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+ "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+ "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+ "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+ "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+ "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+ "double_blocks": "transformer_blocks",
+ "img_attn_q_norm": "attn.norm_q",
+ "img_attn_k_norm": "attn.norm_k",
+ "img_attn_proj": "attn.to_out.0",
+ "txt_attn_q_norm": "attn.norm_added_q",
+ "txt_attn_k_norm": "attn.norm_added_k",
+ "txt_attn_proj": "attn.to_add_out",
+ "img_mod.linear": "norm1.linear",
+ "img_norm1": "norm1.norm",
+ "img_norm2": "norm2",
+ "img_mlp": "ff",
+ "txt_mod.linear": "norm1_context.linear",
+ "txt_norm1": "norm1.norm",
+ "txt_norm2": "norm2_context",
+ "txt_mlp": "ff_context",
+ "self_attn_proj": "attn.to_out.0",
+ "modulation.linear": "norm.linear",
+ "pre_norm": "norm.norm",
+ "final_layer.norm_final": "norm_out.norm",
+ "final_layer.linear": "proj_out",
+ "fc1": "net.0.proj",
+ "fc2": "net.2",
+ "input_embedder": "proj_in",
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "txt_in": remap_txt_in_,
+ "img_attn_qkv": remap_img_attn_qkv_,
+ "txt_attn_qkv": remap_txt_attn_qkv_,
+ "single_blocks": remap_single_transformer_blocks_,
+ "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+}
+
+VAE_KEYS_RENAME_DICT = {}
+
+VAE_SPECIAL_KEYS_REMAP = {}
+
+
+TRANSFORMER_CONFIGS = {
+ "HYVideo-T/2-cfgdistill": {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 24,
+ "attention_head_dim": 128,
+ "num_layers": 20,
+ "num_single_layers": 40,
+ "num_refiner_layers": 2,
+ "mlp_ratio": 4.0,
+ "patch_size": 2,
+ "patch_size_t": 1,
+ "qk_norm": "rms_norm",
+ "guidance_embeds": True,
+ "text_embed_dim": 4096,
+ "pooled_projection_dim": 768,
+ "rope_theta": 256.0,
+ "rope_axes_dim": (16, 56, 56),
+ "image_condition_type": None,
+ },
+ "HYVideo-T/2-I2V-33ch": {
+ "in_channels": 16 * 2 + 1,
+ "out_channels": 16,
+ "num_attention_heads": 24,
+ "attention_head_dim": 128,
+ "num_layers": 20,
+ "num_single_layers": 40,
+ "num_refiner_layers": 2,
+ "mlp_ratio": 4.0,
+ "patch_size": 2,
+ "patch_size_t": 1,
+ "qk_norm": "rms_norm",
+ "guidance_embeds": False,
+ "text_embed_dim": 4096,
+ "pooled_projection_dim": 768,
+ "rope_theta": 256.0,
+ "rope_axes_dim": (16, 56, 56),
+ "image_condition_type": "latent_concat",
+ },
+ "HYVideo-T/2-I2V-16ch": {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 24,
+ "attention_head_dim": 128,
+ "num_layers": 20,
+ "num_single_layers": 40,
+ "num_refiner_layers": 2,
+ "mlp_ratio": 4.0,
+ "patch_size": 2,
+ "patch_size_t": 1,
+ "qk_norm": "rms_norm",
+ "guidance_embeds": True,
+ "text_embed_dim": 4096,
+ "pooled_projection_dim": 768,
+ "rope_theta": 256.0,
+ "rope_axes_dim": (16, 56, 56),
+ "image_condition_type": "token_replace",
+ },
+}
+
+
+def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
+ state_dict = saved_dict
+ if "model" in saved_dict.keys():
+ state_dict = state_dict["model"]
+ if "module" in saved_dict.keys():
+ state_dict = state_dict["module"]
+ if "state_dict" in saved_dict.keys():
+ state_dict = state_dict["state_dict"]
+ return state_dict
+
+
+def convert_transformer(ckpt_path: str, transformer_type: str):
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
+ config = TRANSFORMER_CONFIGS[transformer_type]
+
+ with init_empty_weights():
+ transformer = HunyuanVideoTransformer3DModel(**config)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+ return transformer
+
+
+def convert_vae(ckpt_path: str):
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
+
+ with init_empty_weights():
+ vae = AutoencoderKLHunyuanVideo()
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ vae.load_state_dict(original_state_dict, strict=True, assign=True)
+ return vae
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+ )
+ parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
+ parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
+ parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
+ parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
+ parser.add_argument("--save_pipeline", action="store_true")
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
+ parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
+ parser.add_argument(
+ "--transformer_type", type=str, default="HYVideo-T/2-cfgdistill", choices=list(TRANSFORMER_CONFIGS.keys())
+ )
+ parser.add_argument("--flow_shift", type=float, default=7.0)
+ return parser.parse_args()
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ transformer = None
+ dtype = DTYPE_MAPPING[args.dtype]
+
+ if args.save_pipeline:
+ assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
+ assert args.text_encoder_path is not None
+ assert args.tokenizer_path is not None
+ assert args.text_encoder_2_path is not None
+
+ if args.transformer_ckpt_path is not None:
+ transformer = convert_transformer(args.transformer_ckpt_path, args.transformer_type)
+ transformer = transformer.to(dtype=dtype)
+ if not args.save_pipeline:
+ transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+ if args.vae_ckpt_path is not None:
+ vae = convert_vae(args.vae_ckpt_path)
+ if not args.save_pipeline:
+ vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+ if args.save_pipeline:
+ if args.transformer_type == "HYVideo-T/2-cfgdistill":
+ text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
+ text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
+ tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
+
+ pipe = HunyuanVideoPipeline(
+ transformer=transformer,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ scheduler=scheduler,
+ )
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+ else:
+ text_encoder = LlavaForConditionalGeneration.from_pretrained(
+ args.text_encoder_path, torch_dtype=torch.float16
+ )
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
+ text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
+ tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
+ image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_path)
+
+ pipe = HunyuanVideoImageToVideoPipeline(
+ transformer=transformer,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ )
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
diff --git a/scripts/convert_i2vgen_to_diffusers.py b/scripts/convert_i2vgen_to_diffusers.py
index b9e3ff2cd35c..643780caac2d 100644
--- a/scripts/convert_i2vgen_to_diffusers.py
+++ b/scripts/convert_i2vgen_to_diffusers.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_ldm_original_checkpoint_to_diffusers.py b/scripts/convert_ldm_original_checkpoint_to_diffusers.py
index ada7dc6e2950..cdaf317af752 100644
--- a/scripts/convert_ldm_original_checkpoint_to_diffusers.py
+++ b/scripts/convert_ldm_original_checkpoint_to_diffusers.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py
new file mode 100644
index 000000000000..2e966d5d110b
--- /dev/null
+++ b/scripts/convert_ltx_to_diffusers.py
@@ -0,0 +1,371 @@
+import argparse
+from pathlib import Path
+from typing import Any, Dict
+
+import torch
+from accelerate import init_empty_weights
+from safetensors.torch import load_file
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
+
+
+def remove_keys_(key: str, state_dict: Dict[str, Any]):
+ state_dict.pop(key)
+
+
+TOKENIZER_MAX_LENGTH = 128
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+ "patchify_proj": "proj_in",
+ "adaln_single": "time_embed",
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "vae": remove_keys_,
+}
+
+VAE_KEYS_RENAME_DICT = {
+ # decoder
+ "up_blocks.0": "mid_block",
+ "up_blocks.1": "up_blocks.0",
+ "up_blocks.2": "up_blocks.1.upsamplers.0",
+ "up_blocks.3": "up_blocks.1",
+ "up_blocks.4": "up_blocks.2.conv_in",
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
+ "up_blocks.6": "up_blocks.2",
+ "up_blocks.7": "up_blocks.3.conv_in",
+ "up_blocks.8": "up_blocks.3.upsamplers.0",
+ "up_blocks.9": "up_blocks.3",
+ # encoder
+ "down_blocks.0": "down_blocks.0",
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
+ "down_blocks.2": "down_blocks.0.conv_out",
+ "down_blocks.3": "down_blocks.1",
+ "down_blocks.4": "down_blocks.1.downsamplers.0",
+ "down_blocks.5": "down_blocks.1.conv_out",
+ "down_blocks.6": "down_blocks.2",
+ "down_blocks.7": "down_blocks.2.downsamplers.0",
+ "down_blocks.8": "down_blocks.3",
+ "down_blocks.9": "mid_block",
+ # common
+ "conv_shortcut": "conv_shortcut.conv",
+ "res_blocks": "resnets",
+ "norm3.norm": "norm3",
+ "per_channel_statistics.mean-of-means": "latents_mean",
+ "per_channel_statistics.std-of-means": "latents_std",
+}
+
+VAE_091_RENAME_DICT = {
+ # decoder
+ "up_blocks.0": "mid_block",
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
+ "up_blocks.2": "up_blocks.0",
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
+ "up_blocks.4": "up_blocks.1",
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
+ "up_blocks.6": "up_blocks.2",
+ "up_blocks.7": "up_blocks.3.upsamplers.0",
+ "up_blocks.8": "up_blocks.3",
+ # common
+ "last_time_embedder": "time_embedder",
+ "last_scale_shift_table": "scale_shift_table",
+}
+
+VAE_095_RENAME_DICT = {
+ # decoder
+ "up_blocks.0": "mid_block",
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
+ "up_blocks.2": "up_blocks.0",
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
+ "up_blocks.4": "up_blocks.1",
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
+ "up_blocks.6": "up_blocks.2",
+ "up_blocks.7": "up_blocks.3.upsamplers.0",
+ "up_blocks.8": "up_blocks.3",
+ # encoder
+ "down_blocks.0": "down_blocks.0",
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
+ "down_blocks.2": "down_blocks.1",
+ "down_blocks.3": "down_blocks.1.downsamplers.0",
+ "down_blocks.4": "down_blocks.2",
+ "down_blocks.5": "down_blocks.2.downsamplers.0",
+ "down_blocks.6": "down_blocks.3",
+ "down_blocks.7": "down_blocks.3.downsamplers.0",
+ "down_blocks.8": "mid_block",
+ # common
+ "last_time_embedder": "time_embedder",
+ "last_scale_shift_table": "scale_shift_table",
+}
+
+VAE_SPECIAL_KEYS_REMAP = {
+ "per_channel_statistics.channel": remove_keys_,
+ "per_channel_statistics.mean-of-means": remove_keys_,
+ "per_channel_statistics.mean-of-stds": remove_keys_,
+ "model.diffusion_model": remove_keys_,
+}
+
+
+def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
+ state_dict = saved_dict
+ if "model" in saved_dict.keys():
+ state_dict = state_dict["model"]
+ if "module" in saved_dict.keys():
+ state_dict = state_dict["module"]
+ if "state_dict" in saved_dict.keys():
+ state_dict = state_dict["state_dict"]
+ return state_dict
+
+
+def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def convert_transformer(
+ ckpt_path: str,
+ dtype: torch.dtype,
+ version: str = "0.9.0",
+):
+ PREFIX_KEY = "model.diffusion_model."
+
+ original_state_dict = get_state_dict(load_file(ckpt_path))
+ config = {}
+ if version == "0.9.5":
+ config["_use_causal_rope_fix"] = True
+ with init_empty_weights():
+ transformer = LTXVideoTransformer3DModel(**config)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ if new_key.startswith(PREFIX_KEY):
+ new_key = key[len(PREFIX_KEY) :]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+ return transformer
+
+
+def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
+ PREFIX_KEY = "vae."
+
+ original_state_dict = get_state_dict(load_file(ckpt_path))
+ with init_empty_weights():
+ vae = AutoencoderKLLTXVideo(**config)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ if new_key.startswith(PREFIX_KEY):
+ new_key = key[len(PREFIX_KEY) :]
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ vae.load_state_dict(original_state_dict, strict=True, assign=True)
+ return vae
+
+
+def get_vae_config(version: str) -> Dict[str, Any]:
+ if version == "0.9.0":
+ config = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 128,
+ "block_out_channels": (128, 256, 512, 512),
+ "down_block_types": (
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ ),
+ "decoder_block_out_channels": (128, 256, 512, 512),
+ "layers_per_block": (4, 3, 3, 3, 4),
+ "decoder_layers_per_block": (4, 3, 3, 3, 4),
+ "spatio_temporal_scaling": (True, True, True, False),
+ "decoder_spatio_temporal_scaling": (True, True, True, False),
+ "decoder_inject_noise": (False, False, False, False, False),
+ "downsample_type": ("conv", "conv", "conv", "conv"),
+ "upsample_residual": (False, False, False, False),
+ "upsample_factor": (1, 1, 1, 1),
+ "patch_size": 4,
+ "patch_size_t": 1,
+ "resnet_norm_eps": 1e-6,
+ "scaling_factor": 1.0,
+ "encoder_causal": True,
+ "decoder_causal": False,
+ "timestep_conditioning": False,
+ }
+ elif version == "0.9.1":
+ config = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 128,
+ "block_out_channels": (128, 256, 512, 512),
+ "down_block_types": (
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ ),
+ "decoder_block_out_channels": (256, 512, 1024),
+ "layers_per_block": (4, 3, 3, 3, 4),
+ "decoder_layers_per_block": (5, 6, 7, 8),
+ "spatio_temporal_scaling": (True, True, True, False),
+ "decoder_spatio_temporal_scaling": (True, True, True),
+ "decoder_inject_noise": (True, True, True, False),
+ "downsample_type": ("conv", "conv", "conv", "conv"),
+ "upsample_residual": (True, True, True),
+ "upsample_factor": (2, 2, 2),
+ "timestep_conditioning": True,
+ "patch_size": 4,
+ "patch_size_t": 1,
+ "resnet_norm_eps": 1e-6,
+ "scaling_factor": 1.0,
+ "encoder_causal": True,
+ "decoder_causal": False,
+ }
+ VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
+ elif version == "0.9.5":
+ config = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 128,
+ "block_out_channels": (128, 256, 512, 1024, 2048),
+ "down_block_types": (
+ "LTXVideo095DownBlock3D",
+ "LTXVideo095DownBlock3D",
+ "LTXVideo095DownBlock3D",
+ "LTXVideo095DownBlock3D",
+ ),
+ "decoder_block_out_channels": (256, 512, 1024),
+ "layers_per_block": (4, 6, 6, 2, 2),
+ "decoder_layers_per_block": (5, 5, 5, 5),
+ "spatio_temporal_scaling": (True, True, True, True),
+ "decoder_spatio_temporal_scaling": (True, True, True),
+ "decoder_inject_noise": (False, False, False, False),
+ "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
+ "upsample_residual": (True, True, True),
+ "upsample_factor": (2, 2, 2),
+ "timestep_conditioning": True,
+ "patch_size": 4,
+ "patch_size_t": 1,
+ "resnet_norm_eps": 1e-6,
+ "scaling_factor": 1.0,
+ "encoder_causal": True,
+ "decoder_causal": False,
+ "spatial_compression_ratio": 32,
+ "temporal_compression_ratio": 8,
+ }
+ VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
+ return config
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+ )
+ parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
+ parser.add_argument(
+ "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
+ )
+ parser.add_argument(
+ "--typecast_text_encoder",
+ action="store_true",
+ default=False,
+ help="Whether or not to apply fp16/bf16 precision to text_encoder",
+ )
+ parser.add_argument("--save_pipeline", action="store_true")
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
+ parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
+ parser.add_argument(
+ "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
+ )
+ return parser.parse_args()
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+VARIANT_MAPPING = {
+ "fp32": None,
+ "fp16": "fp16",
+ "bf16": "bf16",
+}
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ transformer = None
+ dtype = DTYPE_MAPPING[args.dtype]
+ variant = VARIANT_MAPPING[args.dtype]
+ output_path = Path(args.output_path)
+
+ if args.save_pipeline:
+ assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
+
+ if args.transformer_ckpt_path is not None:
+ transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
+ if not args.save_pipeline:
+ transformer.save_pretrained(
+ output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
+ )
+
+ if args.vae_ckpt_path is not None:
+ config = get_vae_config(args.version)
+ vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
+ if not args.save_pipeline:
+ vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
+
+ if args.save_pipeline:
+ text_encoder_id = "google/t5-v1_1-xxl"
+ tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+ text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+ if args.typecast_text_encoder:
+ text_encoder = text_encoder.to(dtype=dtype)
+
+ # Apparently, the conversion does not work anymore without this :shrug:
+ for param in text_encoder.parameters():
+ param.data = param.data.contiguous()
+
+ if args.version == "0.9.5":
+ scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
+ else:
+ scheduler = FlowMatchEulerDiscreteScheduler(
+ use_dynamic_shifting=True,
+ base_shift=0.95,
+ max_shift=2.05,
+ base_image_seq_len=1024,
+ max_image_seq_len=4096,
+ shift_terminal=0.1,
+ )
+
+ pipe = LTXPipeline(
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ )
+
+ pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
diff --git a/scripts/convert_lumina_to_diffusers.py b/scripts/convert_lumina_to_diffusers.py
index a12625d1376f..c14aad3c6bf2 100644
--- a/scripts/convert_lumina_to_diffusers.py
+++ b/scripts/convert_lumina_to_diffusers.py
@@ -5,7 +5,7 @@
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline
def main(args):
@@ -115,7 +115,7 @@ def main(args):
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
text_encoder = AutoModel.from_pretrained("google/gemma-2b")
- pipeline = LuminaText2ImgPipeline(
+ pipeline = LuminaPipeline(
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
)
pipeline.save_pretrained(args.dump_path)
diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py
new file mode 100644
index 000000000000..9727deeb6b0c
--- /dev/null
+++ b/scripts/convert_mochi_to_diffusers.py
@@ -0,0 +1,461 @@
+import argparse
+from contextlib import nullcontext
+
+import torch
+from accelerate import init_empty_weights
+from safetensors.torch import load_file
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+TOKENIZER_MAX_LENGTH = 256
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
+parser.add_argument("--vae_encoder_checkpoint_path", default=None, type=str)
+parser.add_argument("--vae_decoder_checkpoint_path", default=None, type=str)
+parser.add_argument("--output_path", required=True, type=str)
+parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
+parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
+parser.add_argument("--dtype", type=str, default=None)
+
+args = parser.parse_args()
+
+
+# This is specific to `AdaLayerNormContinuous`:
+# Diffusers implementation split the linear projection into the scale, shift while Mochi split it into shift, scale
+def swap_scale_shift(weight, dim):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
+def swap_proj_gate(weight):
+ proj, gate = weight.chunk(2, dim=0)
+ new_weight = torch.cat([gate, proj], dim=0)
+ return new_weight
+
+
+def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
+ original_state_dict = load_file(ckpt_path, device="cpu")
+ new_state_dict = {}
+
+ # Convert patch_embed
+ new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
+ new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")
+
+ # Convert time_embed
+ new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight")
+ new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias")
+ new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight")
+ new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias")
+ new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight")
+ new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias")
+ new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight")
+ new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias")
+ new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight")
+ new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias")
+ new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight")
+ new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias")
+
+ # Convert transformer blocks
+ num_layers = 48
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ old_prefix = f"blocks.{i}."
+
+ # norm1
+ new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight")
+ new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias")
+ if i < num_layers - 1:
+ new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop(
+ old_prefix + "mod_y.weight"
+ )
+ new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop(
+ old_prefix + "mod_y.bias"
+ )
+ else:
+ new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
+ old_prefix + "mod_y.weight"
+ )
+ new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
+ old_prefix + "mod_y.bias"
+ )
+
+ # Visual attention
+ qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+ new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop(
+ old_prefix + "attn.q_norm_x.weight"
+ )
+ new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop(
+ old_prefix + "attn.k_norm_x.weight"
+ )
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
+ old_prefix + "attn.proj_x.weight"
+ )
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias")
+
+ # Context attention
+ qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+
+ new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+ new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+ new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+ new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop(
+ old_prefix + "attn.q_norm_y.weight"
+ )
+ new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop(
+ old_prefix + "attn.k_norm_y.weight"
+ )
+ if i < num_layers - 1:
+ new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop(
+ old_prefix + "attn.proj_y.weight"
+ )
+ new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop(
+ old_prefix + "attn.proj_y.bias"
+ )
+
+ # MLP
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+ original_state_dict.pop(old_prefix + "mlp_x.w1.weight")
+ )
+ new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight")
+ if i < num_layers - 1:
+ new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+ original_state_dict.pop(old_prefix + "mlp_y.w1.weight")
+ )
+ new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop(
+ old_prefix + "mlp_y.w2.weight"
+ )
+
+ # Output layers
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ original_state_dict.pop("final_layer.mod.weight"), dim=0
+ )
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.bias"), dim=0)
+ new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
+ new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
+
+ new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies")
+
+ print("Remaining Keys:", original_state_dict.keys())
+
+ return new_state_dict
+
+
+def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_path):
+ encoder_state_dict = load_file(encoder_ckpt_path, device="cpu")
+ decoder_state_dict = load_file(decoder_ckpt_path, device="cpu")
+ new_state_dict = {}
+
+ # ==== Decoder =====
+ prefix = "decoder."
+
+ # Convert conv_in
+ new_state_dict[f"{prefix}conv_in.weight"] = decoder_state_dict.pop("blocks.0.0.weight")
+ new_state_dict[f"{prefix}conv_in.bias"] = decoder_state_dict.pop("blocks.0.0.bias")
+
+ # Convert block_in (MochiMidBlock3D)
+ for i in range(3): # layers_per_block[-1] = 3
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
+ f"blocks.0.{i+1}.stack.0.weight"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
+ f"blocks.0.{i+1}.stack.0.bias"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
+ f"blocks.0.{i+1}.stack.2.weight"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
+ f"blocks.0.{i+1}.stack.2.bias"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
+ f"blocks.0.{i+1}.stack.3.weight"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
+ f"blocks.0.{i+1}.stack.3.bias"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
+ f"blocks.0.{i+1}.stack.5.weight"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
+ f"blocks.0.{i+1}.stack.5.bias"
+ )
+
+ # Convert up_blocks (MochiUpBlock3D)
+ down_block_layers = [6, 4, 3] # layers_per_block[-2], layers_per_block[-3], layers_per_block[-4]
+ for block in range(3):
+ for i in range(down_block_layers[block]):
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
+ f"blocks.{block+1}.blocks.{i}.stack.0.weight"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
+ f"blocks.{block+1}.blocks.{i}.stack.0.bias"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
+ f"blocks.{block+1}.blocks.{i}.stack.2.weight"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
+ f"blocks.{block+1}.blocks.{i}.stack.2.bias"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
+ f"blocks.{block+1}.blocks.{i}.stack.3.weight"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
+ f"blocks.{block+1}.blocks.{i}.stack.3.bias"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
+ f"blocks.{block+1}.blocks.{i}.stack.5.weight"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
+ f"blocks.{block+1}.blocks.{i}.stack.5.bias"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
+ f"blocks.{block+1}.proj.weight"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
+
+ # Convert block_out (MochiMidBlock3D)
+ for i in range(3): # layers_per_block[0] = 3
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
+ f"blocks.4.{i}.stack.0.weight"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
+ f"blocks.4.{i}.stack.0.bias"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
+ f"blocks.4.{i}.stack.2.weight"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
+ f"blocks.4.{i}.stack.2.bias"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
+ f"blocks.4.{i}.stack.3.weight"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
+ f"blocks.4.{i}.stack.3.bias"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
+ f"blocks.4.{i}.stack.5.weight"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
+ f"blocks.4.{i}.stack.5.bias"
+ )
+
+ # Convert proj_out (Conv1x1 ~= nn.Linear)
+ new_state_dict[f"{prefix}proj_out.weight"] = decoder_state_dict.pop("output_proj.weight")
+ new_state_dict[f"{prefix}proj_out.bias"] = decoder_state_dict.pop("output_proj.bias")
+
+ print("Remaining Decoder Keys:", decoder_state_dict.keys())
+
+ # ==== Encoder =====
+ prefix = "encoder."
+
+ new_state_dict[f"{prefix}proj_in.weight"] = encoder_state_dict.pop("layers.0.weight")
+ new_state_dict[f"{prefix}proj_in.bias"] = encoder_state_dict.pop("layers.0.bias")
+
+ # Convert block_in (MochiMidBlock3D)
+ for i in range(3): # layers_per_block[0] = 3
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
+ f"layers.{i+1}.stack.0.weight"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
+ f"layers.{i+1}.stack.0.bias"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
+ f"layers.{i+1}.stack.2.weight"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
+ f"layers.{i+1}.stack.2.bias"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
+ f"layers.{i+1}.stack.3.weight"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
+ f"layers.{i+1}.stack.3.bias"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
+ f"layers.{i+1}.stack.5.weight"
+ )
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
+ f"layers.{i+1}.stack.5.bias"
+ )
+
+ # Convert down_blocks (MochiDownBlock3D)
+ down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
+ for block in range(3):
+ new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.0.weight"
+ )
+ new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.0.bias"
+ )
+
+ for i in range(down_block_layers[block]):
+ # Convert resnets
+ new_state_dict[
+ f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
+ ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight")
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.{i+1}.stack.0.bias"
+ )
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.{i+1}.stack.2.weight"
+ )
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.{i+1}.stack.2.bias"
+ )
+ new_state_dict[
+ f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
+ ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.{i+1}.stack.3.bias"
+ )
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.{i+1}.stack.5.weight"
+ )
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.{i+1}.stack.5.bias"
+ )
+
+ # Convert attentions
+ qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+
+ new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
+ new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
+ new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
+ new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"
+ )
+ new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"
+ )
+ new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"
+ )
+ new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
+ f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"
+ )
+
+ # Convert block_out (MochiMidBlock3D)
+ for i in range(3): # layers_per_block[-1] = 3
+ # Convert resnets
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
+ f"layers.{i+7}.stack.0.weight"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
+ f"layers.{i+7}.stack.0.bias"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
+ f"layers.{i+7}.stack.2.weight"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
+ f"layers.{i+7}.stack.2.bias"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
+ f"layers.{i+7}.stack.3.weight"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
+ f"layers.{i+7}.stack.3.bias"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
+ f"layers.{i+7}.stack.5.weight"
+ )
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
+ f"layers.{i+7}.stack.5.bias"
+ )
+
+ # Convert attentions
+ qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+
+ new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
+ new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
+ new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
+ new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
+ f"layers.{i+7}.attn_block.attn.out.weight"
+ )
+ new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
+ f"layers.{i+7}.attn_block.attn.out.bias"
+ )
+ new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
+ f"layers.{i+7}.attn_block.norm.weight"
+ )
+ new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
+ f"layers.{i+7}.attn_block.norm.bias"
+ )
+
+ # Convert output layers
+ new_state_dict[f"{prefix}norm_out.norm_layer.weight"] = encoder_state_dict.pop("output_norm.weight")
+ new_state_dict[f"{prefix}norm_out.norm_layer.bias"] = encoder_state_dict.pop("output_norm.bias")
+ new_state_dict[f"{prefix}proj_out.weight"] = encoder_state_dict.pop("output_proj.weight")
+
+ print("Remaining Encoder Keys:", encoder_state_dict.keys())
+
+ return new_state_dict
+
+
+def main(args):
+ if args.dtype is None:
+ dtype = None
+ if args.dtype == "fp16":
+ dtype = torch.float16
+ elif args.dtype == "bf16":
+ dtype = torch.bfloat16
+ elif args.dtype == "fp32":
+ dtype = torch.float32
+ else:
+ raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+ transformer = None
+ vae = None
+
+ if args.transformer_checkpoint_path is not None:
+ converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
+ args.transformer_checkpoint_path
+ )
+ transformer = MochiTransformer3DModel()
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+ if dtype is not None:
+ transformer = transformer.to(dtype=dtype)
+
+ if args.vae_encoder_checkpoint_path is not None and args.vae_decoder_checkpoint_path is not None:
+ vae = AutoencoderKLMochi(latent_channels=12, out_channels=3)
+ converted_vae_state_dict = convert_mochi_vae_state_dict_to_diffusers(
+ args.vae_encoder_checkpoint_path, args.vae_decoder_checkpoint_path
+ )
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
+ if dtype is not None:
+ vae = vae.to(dtype=dtype)
+
+ text_encoder_id = "google/t5-v1_1-xxl"
+ tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+ text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+ # Apparently, the conversion does not work anymore without this :shrug:
+ for param in text_encoder.parameters():
+ param.data = param.data.contiguous()
+
+ pipe = MochiPipeline(
+ scheduler=FlowMatchEulerDiscreteScheduler(invert_sigmas=True),
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ )
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/scripts/convert_ms_text_to_video_to_diffusers.py b/scripts/convert_ms_text_to_video_to_diffusers.py
index 0251ab680d59..e150a491a0b0 100644
--- a/scripts/convert_ms_text_to_video_to_diffusers.py
+++ b/scripts/convert_ms_text_to_video_to_diffusers.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py b/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
index 2d67123d9ad7..bcab90e2a3db 100644
--- a/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
+++ b/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py
new file mode 100644
index 000000000000..96bc935633f0
--- /dev/null
+++ b/scripts/convert_omnigen_to_diffusers.py
@@ -0,0 +1,203 @@
+import argparse
+import os
+
+import torch
+from huggingface_hub import snapshot_download
+from safetensors.torch import load_file
+from transformers import AutoTokenizer
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
+
+
+def main(args):
+ # checkpoint from https://huggingface.co/Shitao/OmniGen-v1
+
+ if not os.path.exists(args.origin_ckpt_path):
+ print("Model not found, downloading...")
+ cache_folder = os.getenv("HF_HUB_CACHE")
+ args.origin_ckpt_path = snapshot_download(
+ repo_id=args.origin_ckpt_path,
+ cache_dir=cache_folder,
+ ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"],
+ )
+ print(f"Downloaded model to {args.origin_ckpt_path}")
+
+ ckpt = os.path.join(args.origin_ckpt_path, "model.safetensors")
+ ckpt = load_file(ckpt, device="cpu")
+
+ mapping_dict = {
+ "pos_embed": "patch_embedding.pos_embed",
+ "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
+ "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
+ "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
+ "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
+ "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
+ "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
+ "final_layer.linear.weight": "proj_out.weight",
+ "final_layer.linear.bias": "proj_out.bias",
+ "time_token.mlp.0.weight": "time_token.linear_1.weight",
+ "time_token.mlp.0.bias": "time_token.linear_1.bias",
+ "time_token.mlp.2.weight": "time_token.linear_2.weight",
+ "time_token.mlp.2.bias": "time_token.linear_2.bias",
+ "t_embedder.mlp.0.weight": "t_embedder.linear_1.weight",
+ "t_embedder.mlp.0.bias": "t_embedder.linear_1.bias",
+ "t_embedder.mlp.2.weight": "t_embedder.linear_2.weight",
+ "t_embedder.mlp.2.bias": "t_embedder.linear_2.bias",
+ "llm.embed_tokens.weight": "embed_tokens.weight",
+ }
+
+ converted_state_dict = {}
+ for k, v in ckpt.items():
+ if k in mapping_dict:
+ converted_state_dict[mapping_dict[k]] = v
+ elif "qkv" in k:
+ to_q, to_k, to_v = v.chunk(3)
+ converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_q.weight"] = to_q
+ converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_k.weight"] = to_k
+ converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_v.weight"] = to_v
+ elif "o_proj" in k:
+ converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_out.0.weight"] = v
+ else:
+ converted_state_dict[k[4:]] = v
+
+ transformer = OmniGenTransformer2DModel(
+ rope_scaling={
+ "long_factor": [
+ 1.0299999713897705,
+ 1.0499999523162842,
+ 1.0499999523162842,
+ 1.0799999237060547,
+ 1.2299998998641968,
+ 1.2299998998641968,
+ 1.2999999523162842,
+ 1.4499999284744263,
+ 1.5999999046325684,
+ 1.6499998569488525,
+ 1.8999998569488525,
+ 2.859999895095825,
+ 3.68999981880188,
+ 5.419999599456787,
+ 5.489999771118164,
+ 5.489999771118164,
+ 9.09000015258789,
+ 11.579999923706055,
+ 15.65999984741211,
+ 15.769999504089355,
+ 15.789999961853027,
+ 18.360000610351562,
+ 21.989999771118164,
+ 23.079999923706055,
+ 30.009998321533203,
+ 32.35000228881836,
+ 32.590003967285156,
+ 35.56000518798828,
+ 39.95000457763672,
+ 53.840003967285156,
+ 56.20000457763672,
+ 57.95000457763672,
+ 59.29000473022461,
+ 59.77000427246094,
+ 59.920005798339844,
+ 61.190006256103516,
+ 61.96000671386719,
+ 62.50000762939453,
+ 63.3700065612793,
+ 63.48000717163086,
+ 63.48000717163086,
+ 63.66000747680664,
+ 63.850006103515625,
+ 64.08000946044922,
+ 64.760009765625,
+ 64.80001068115234,
+ 64.81001281738281,
+ 64.81001281738281,
+ ],
+ "short_factor": [
+ 1.05,
+ 1.05,
+ 1.05,
+ 1.1,
+ 1.1,
+ 1.1,
+ 1.2500000000000002,
+ 1.2500000000000002,
+ 1.4000000000000004,
+ 1.4500000000000004,
+ 1.5500000000000005,
+ 1.8500000000000008,
+ 1.9000000000000008,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.1000000000000005,
+ 2.1000000000000005,
+ 2.2,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3999999999999995,
+ 2.3999999999999995,
+ 2.6499999999999986,
+ 2.6999999999999984,
+ 2.8999999999999977,
+ 2.9499999999999975,
+ 3.049999999999997,
+ 3.049999999999997,
+ 3.049999999999997,
+ ],
+ "type": "su",
+ },
+ patch_size=2,
+ in_channels=4,
+ pos_embed_max_size=192,
+ )
+ transformer.load_state_dict(converted_state_dict, strict=True)
+ transformer.to(torch.bfloat16)
+
+ num_model_params = sum(p.numel() for p in transformer.parameters())
+ print(f"Total number of transformer parameters: {num_model_params}")
+
+ scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)
+
+ vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)
+
+ tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)
+
+ pipeline = OmniGenPipeline(tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler)
+ pipeline.save_pretrained(args.dump_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--origin_ckpt_path",
+ default="Shitao/OmniGen-v1",
+ type=str,
+ required=False,
+ help="Path to the checkpoint to convert.",
+ )
+
+ parser.add_argument(
+ "--dump_path", default="OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline."
+ )
+
+ args = parser.parse_args()
+ main(args)
diff --git a/scripts/convert_original_audioldm2_to_diffusers.py b/scripts/convert_original_audioldm2_to_diffusers.py
index ea9c02d53815..1dc7d739ea76 100644
--- a/scripts/convert_original_audioldm2_to_diffusers.py
+++ b/scripts/convert_original_audioldm2_to_diffusers.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_original_audioldm_to_diffusers.py b/scripts/convert_original_audioldm_to_diffusers.py
index 797d19826091..4f8e4f8f9f80 100644
--- a/scripts/convert_original_audioldm_to_diffusers.py
+++ b/scripts/convert_original_audioldm_to_diffusers.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_original_controlnet_to_diffusers.py b/scripts/convert_original_controlnet_to_diffusers.py
index 92aad4f09e70..4c6fe90cb09f 100644
--- a/scripts/convert_original_controlnet_to_diffusers.py
+++ b/scripts/convert_original_controlnet_to_diffusers.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_original_musicldm_to_diffusers.py b/scripts/convert_original_musicldm_to_diffusers.py
index 6db9dbdfdb74..61e5d16eea9e 100644
--- a/scripts/convert_original_musicldm_to_diffusers.py
+++ b/scripts/convert_original_musicldm_to_diffusers.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py
index 7e7925b0a412..59eeeec24c79 100644
--- a/scripts/convert_original_stable_diffusion_to_diffusers.py
+++ b/scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_original_t2i_adapter.py b/scripts/convert_original_t2i_adapter.py
index 95c8817b508d..e23a2431ce9e 100644
--- a/scripts/convert_original_t2i_adapter.py
+++ b/scripts/convert_original_t2i_adapter.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py
new file mode 100644
index 000000000000..1c40072177c6
--- /dev/null
+++ b/scripts/convert_sana_to_diffusers.py
@@ -0,0 +1,456 @@
+#!/usr/bin/env python
+from __future__ import annotations
+
+import argparse
+import os
+from contextlib import nullcontext
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
+from termcolor import colored
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from diffusers import (
+ AutoencoderDC,
+ DPMSolverMultistepScheduler,
+ FlowMatchEulerDiscreteScheduler,
+ SanaPipeline,
+ SanaSprintPipeline,
+ SanaTransformer2DModel,
+ SCMScheduler,
+)
+from diffusers.models.modeling_utils import load_model_dict_into_meta
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available else nullcontext
+
+ckpt_ids = [
+ "Efficient-Large-Model/Sana_Sprint_0.6B_1024px/checkpoints/Sana_Sprint_0.6B_1024px.pth"
+ "Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth"
+ "Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
+ "Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth",
+ "Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
+ "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
+ "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
+ "Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
+ "Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
+ "Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
+ "Efficient-Large-Model/Sana_1600M_512px/checkpoints/Sana_1600M_512px.pth",
+ "Efficient-Large-Model/Sana_600M_1024px/checkpoints/Sana_600M_1024px_MultiLing.pth",
+ "Efficient-Large-Model/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth",
+]
+# https://github.com/NVlabs/Sana/blob/main/scripts/inference.py
+
+
+def main(args):
+ cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
+
+ if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
+ ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
+ snapshot_download(
+ repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
+ cache_dir=cache_dir_path,
+ repo_type="model",
+ )
+ file_path = hf_hub_download(
+ repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
+ filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
+ cache_dir=cache_dir_path,
+ repo_type="model",
+ )
+ else:
+ file_path = args.orig_ckpt_path
+
+ print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
+ all_state_dict = torch.load(file_path, weights_only=True)
+ state_dict = all_state_dict.pop("state_dict")
+ converted_state_dict = {}
+
+ # Patch embeddings.
+ converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
+ converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
+
+ # Caption projection.
+ converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
+ converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
+ converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
+ converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
+
+ # Handle different time embedding structure based on model type
+
+ if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
+ # For Sana Sprint, the time embedding structure is different
+ converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
+ "t_embedder.mlp.0.weight"
+ )
+ converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop(
+ "t_embedder.mlp.2.weight"
+ )
+ converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
+
+ # Guidance embedder for Sana Sprint
+ converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop(
+ "cfg_embedder.mlp.0.weight"
+ )
+ converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias")
+ converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop(
+ "cfg_embedder.mlp.2.weight"
+ )
+ converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias")
+ else:
+ # Original Sana time embedding structure
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
+ "t_embedder.mlp.0.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop(
+ "t_embedder.mlp.0.bias"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
+ "t_embedder.mlp.2.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop(
+ "t_embedder.mlp.2.bias"
+ )
+
+ # Shared norm.
+ converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
+ converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
+
+ # y norm
+ converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
+
+ # scheduler
+ if args.image_size == 4096:
+ flow_shift = 6.0
+ else:
+ flow_shift = 3.0
+
+ # model config
+ if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]:
+ layer_num = 20
+ elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]:
+ layer_num = 28
+ elif args.model_type == "SanaMS_4800M_P1_D60":
+ layer_num = 60
+ else:
+ raise ValueError(f"{args.model_type} is not supported.")
+ # Positional embedding interpolation scale.
+ interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
+ qk_norm = (
+ "rms_norm_across_heads"
+ if args.model_type
+ in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"]
+ else None
+ )
+
+ for depth in range(layer_num):
+ # Transformer blocks.
+ converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
+ f"blocks.{depth}.scale_shift_table"
+ )
+
+ # Linear Attention is all you need 🤘
+ # Self attention.
+ q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+ if qk_norm is not None:
+ # Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.q_norm.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.k_norm.weight"
+ )
+ # Projection.
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
+ f"blocks.{depth}.attn.proj.bias"
+ )
+
+ # Feed-forward.
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.inverted_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
+ f"blocks.{depth}.mlp.inverted_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.depth_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
+ f"blocks.{depth}.mlp.depth_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.point_conv.conv.weight"
+ )
+
+ # Cross-attention.
+ q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
+ q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
+ k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
+ k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
+ if qk_norm is not None:
+ # Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.q_norm.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.k_norm.weight"
+ )
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.proj.bias"
+ )
+
+ # Final block.
+ converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
+ converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
+
+ # Transformer
+ with CTX():
+ transformer_kwargs = {
+ "in_channels": 32,
+ "out_channels": 32,
+ "num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"],
+ "attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"],
+ "num_layers": model_kwargs[args.model_type]["num_layers"],
+ "num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"],
+ "cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"],
+ "cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"],
+ "caption_channels": 2304,
+ "mlp_ratio": 2.5,
+ "attention_bias": False,
+ "sample_size": args.image_size // 32,
+ "patch_size": 1,
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "interpolation_scale": interpolation_scale[args.image_size],
+ }
+
+ # Add qk_norm parameter for Sana Sprint
+ if args.model_type in [
+ "SanaMS1.5_1600M_P1_D20",
+ "SanaMS1.5_4800M_P1_D60",
+ "SanaSprint_600M_P1_D28",
+ "SanaSprint_1600M_P1_D20",
+ ]:
+ transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
+ if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
+ transformer_kwargs["guidance_embeds"] = True
+
+ transformer = SanaTransformer2DModel(**transformer_kwargs)
+
+ if is_accelerate_available():
+ load_model_dict_into_meta(transformer, converted_state_dict)
+ else:
+ transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
+
+ try:
+ state_dict.pop("y_embedder.y_embedding")
+ state_dict.pop("pos_embed")
+ state_dict.pop("logvar_linear.weight")
+ state_dict.pop("logvar_linear.bias")
+ except KeyError:
+ print("y_embedder.y_embedding or pos_embed not found in the state_dict")
+
+ assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
+
+ num_model_params = sum(p.numel() for p in transformer.parameters())
+ print(f"Total number of transformer parameters: {num_model_params}")
+
+ transformer = transformer.to(weight_dtype)
+
+ if not args.save_full_pipeline:
+ print(
+ colored(
+ f"Only saving transformer model of {args.model_type}. "
+ f"Set --save_full_pipeline to save the whole Pipeline",
+ "green",
+ attrs=["bold"],
+ )
+ )
+ transformer.save_pretrained(
+ os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
+ )
+ else:
+ print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
+ # VAE
+ ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)
+
+ # Text Encoder
+ text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
+ tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
+ tokenizer.padding_side = "right"
+ text_encoder = AutoModelForCausalLM.from_pretrained(
+ text_encoder_model_path, torch_dtype=torch.bfloat16
+ ).get_decoder()
+
+ # Choose the appropriate pipeline and scheduler based on model type
+ if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
+ # Force SCM Scheduler for Sana Sprint regardless of scheduler_type
+ if args.scheduler_type != "scm":
+ print(
+ colored(
+ f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model",
+ "yellow",
+ attrs=["bold"],
+ )
+ )
+
+ # SCM Scheduler for Sana Sprint
+ scheduler_config = {
+ "prediction_type": "trigflow",
+ "sigma_data": 0.5,
+ }
+ scheduler = SCMScheduler(**scheduler_config)
+ pipe = SanaSprintPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ transformer=transformer,
+ vae=ae,
+ scheduler=scheduler,
+ )
+ else:
+ # Original Sana scheduler
+ if args.scheduler_type == "flow-dpm_solver":
+ scheduler = DPMSolverMultistepScheduler(
+ flow_shift=flow_shift,
+ use_flow_sigmas=True,
+ prediction_type="flow_prediction",
+ )
+ elif args.scheduler_type == "flow-euler":
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
+ else:
+ raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
+
+ pipe = SanaPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ transformer=transformer,
+ vae=ae,
+ scheduler=scheduler,
+ )
+
+ pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument(
+ "--image_size",
+ default=1024,
+ type=int,
+ choices=[512, 1024, 2048, 4096],
+ required=False,
+ help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
+ )
+ parser.add_argument(
+ "--model_type",
+ default="SanaMS_1600M_P1_D20",
+ type=str,
+ choices=[
+ "SanaMS_1600M_P1_D20",
+ "SanaMS_600M_P1_D28",
+ "SanaMS1.5_1600M_P1_D20",
+ "SanaMS1.5_4800M_P1_D60",
+ "SanaSprint_1600M_P1_D20",
+ "SanaSprint_600M_P1_D28",
+ ],
+ )
+ parser.add_argument(
+ "--scheduler_type",
+ default="flow-dpm_solver",
+ type=str,
+ choices=["flow-dpm_solver", "flow-euler", "scm"],
+ help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
+ )
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
+ parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
+ parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
+
+ args = parser.parse_args()
+
+ model_kwargs = {
+ "SanaMS_1600M_P1_D20": {
+ "num_attention_heads": 70,
+ "attention_head_dim": 32,
+ "num_cross_attention_heads": 20,
+ "cross_attention_head_dim": 112,
+ "cross_attention_dim": 2240,
+ "num_layers": 20,
+ },
+ "SanaMS_600M_P1_D28": {
+ "num_attention_heads": 36,
+ "attention_head_dim": 32,
+ "num_cross_attention_heads": 16,
+ "cross_attention_head_dim": 72,
+ "cross_attention_dim": 1152,
+ "num_layers": 28,
+ },
+ "SanaMS1.5_1600M_P1_D20": {
+ "num_attention_heads": 70,
+ "attention_head_dim": 32,
+ "num_cross_attention_heads": 20,
+ "cross_attention_head_dim": 112,
+ "cross_attention_dim": 2240,
+ "num_layers": 20,
+ },
+ "SanaMS1.5_4800M_P1_D60": {
+ "num_attention_heads": 70,
+ "attention_head_dim": 32,
+ "num_cross_attention_heads": 20,
+ "cross_attention_head_dim": 112,
+ "cross_attention_dim": 2240,
+ "num_layers": 60,
+ },
+ "SanaSprint_600M_P1_D28": {
+ "num_attention_heads": 36,
+ "attention_head_dim": 32,
+ "num_cross_attention_heads": 16,
+ "cross_attention_head_dim": 72,
+ "cross_attention_dim": 1152,
+ "num_layers": 28,
+ },
+ "SanaSprint_1600M_P1_D20": {
+ "num_attention_heads": 70,
+ "attention_head_dim": 32,
+ "num_cross_attention_heads": 20,
+ "cross_attention_head_dim": 112,
+ "cross_attention_dim": 2240,
+ "num_layers": 20,
+ },
+ }
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ weight_dtype = DTYPE_MAPPING[args.dtype]
+
+ main(args)
diff --git a/scripts/convert_sd3_controlnet_to_diffusers.py b/scripts/convert_sd3_controlnet_to_diffusers.py
new file mode 100644
index 000000000000..171f40a7aa06
--- /dev/null
+++ b/scripts/convert_sd3_controlnet_to_diffusers.py
@@ -0,0 +1,185 @@
+"""
+A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.
+
+Example:
+ Convert a SD3.5 ControlNet checkpoint to Diffusers format using local file:
+ ```bash
+ python scripts/convert_sd3_controlnet_to_diffusers.py \
+ --checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors" \
+ --output_path "output/sd35-controlnet-canny" \
+ --dtype "fp16" # optional, defaults to fp32
+ ```
+
+ Or download and convert from HuggingFace repository:
+ ```bash
+ python scripts/convert_sd3_controlnet_to_diffusers.py \
+ --original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets" \
+ --filename "sd3.5_large_controlnet_canny.safetensors" \
+ --output_path "/raid/yiyi/sd35-controlnet-canny-diffusers" \
+ --dtype "fp32" # optional, defaults to fp32
+ ```
+
+Note:
+ The script supports the following ControlNet types from SD3.5:
+ - Canny edge detection
+ - Depth estimation
+ - Blur detection
+
+ The checkpoint files can be downloaded from:
+ https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
+"""
+
+import argparse
+
+import safetensors.torch
+import torch
+from huggingface_hub import hf_hub_download
+
+from diffusers import SD3ControlNetModel
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
+parser.add_argument(
+ "--original_state_dict_repo_id", type=str, default=None, help="HuggingFace repo ID containing the checkpoint"
+)
+parser.add_argument("--filename", type=str, default=None, help="Filename of the checkpoint in the HF repo")
+parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
+parser.add_argument(
+ "--dtype", type=str, default="fp32", help="Data type for the converted model (fp16, bf16, or fp32)"
+)
+
+args = parser.parse_args()
+
+
+def load_original_checkpoint(args):
+ if args.original_state_dict_repo_id is not None:
+ if args.filename is None:
+ raise ValueError("When using `original_state_dict_repo_id`, `filename` must also be specified")
+ print(f"Downloading checkpoint from {args.original_state_dict_repo_id}/{args.filename}")
+ ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
+ elif args.checkpoint_path is not None:
+ print(f"Loading checkpoint from local path: {args.checkpoint_path}")
+ ckpt_path = args.checkpoint_path
+ else:
+ raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+
+ original_state_dict = safetensors.torch.load_file(ckpt_path)
+ return original_state_dict
+
+
+def convert_sd3_controlnet_checkpoint_to_diffusers(original_state_dict):
+ converted_state_dict = {}
+
+ # Direct mappings for controlnet blocks
+ for i in range(19): # 19 controlnet blocks
+ converted_state_dict[f"controlnet_blocks.{i}.weight"] = original_state_dict[f"controlnet_blocks.{i}.weight"]
+ converted_state_dict[f"controlnet_blocks.{i}.bias"] = original_state_dict[f"controlnet_blocks.{i}.bias"]
+
+ # Positional embeddings
+ converted_state_dict["pos_embed_input.proj.weight"] = original_state_dict["pos_embed_input.proj.weight"]
+ converted_state_dict["pos_embed_input.proj.bias"] = original_state_dict["pos_embed_input.proj.bias"]
+
+ # Time and text embeddings
+ time_text_mappings = {
+ "time_text_embed.timestep_embedder.linear_1.weight": "time_text_embed.timestep_embedder.linear_1.weight",
+ "time_text_embed.timestep_embedder.linear_1.bias": "time_text_embed.timestep_embedder.linear_1.bias",
+ "time_text_embed.timestep_embedder.linear_2.weight": "time_text_embed.timestep_embedder.linear_2.weight",
+ "time_text_embed.timestep_embedder.linear_2.bias": "time_text_embed.timestep_embedder.linear_2.bias",
+ "time_text_embed.text_embedder.linear_1.weight": "time_text_embed.text_embedder.linear_1.weight",
+ "time_text_embed.text_embedder.linear_1.bias": "time_text_embed.text_embedder.linear_1.bias",
+ "time_text_embed.text_embedder.linear_2.weight": "time_text_embed.text_embedder.linear_2.weight",
+ "time_text_embed.text_embedder.linear_2.bias": "time_text_embed.text_embedder.linear_2.bias",
+ }
+
+ for new_key, old_key in time_text_mappings.items():
+ if old_key in original_state_dict:
+ converted_state_dict[new_key] = original_state_dict[old_key]
+
+ # Transformer blocks
+ for i in range(19):
+ # Split QKV into separate Q, K, V
+ qkv_weight = original_state_dict[f"transformer_blocks.{i}.attn.qkv.weight"]
+ qkv_bias = original_state_dict[f"transformer_blocks.{i}.attn.qkv.bias"]
+ q, k, v = torch.chunk(qkv_weight, 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
+
+ block_mappings = {
+ f"transformer_blocks.{i}.attn.to_q.weight": q,
+ f"transformer_blocks.{i}.attn.to_q.bias": q_bias,
+ f"transformer_blocks.{i}.attn.to_k.weight": k,
+ f"transformer_blocks.{i}.attn.to_k.bias": k_bias,
+ f"transformer_blocks.{i}.attn.to_v.weight": v,
+ f"transformer_blocks.{i}.attn.to_v.bias": v_bias,
+ # Output projections
+ f"transformer_blocks.{i}.attn.to_out.0.weight": original_state_dict[
+ f"transformer_blocks.{i}.attn.proj.weight"
+ ],
+ f"transformer_blocks.{i}.attn.to_out.0.bias": original_state_dict[
+ f"transformer_blocks.{i}.attn.proj.bias"
+ ],
+ # Feed forward
+ f"transformer_blocks.{i}.ff.net.0.proj.weight": original_state_dict[
+ f"transformer_blocks.{i}.mlp.fc1.weight"
+ ],
+ f"transformer_blocks.{i}.ff.net.0.proj.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc1.bias"],
+ f"transformer_blocks.{i}.ff.net.2.weight": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.weight"],
+ f"transformer_blocks.{i}.ff.net.2.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.bias"],
+ # Norms
+ f"transformer_blocks.{i}.norm1.linear.weight": original_state_dict[
+ f"transformer_blocks.{i}.adaLN_modulation.1.weight"
+ ],
+ f"transformer_blocks.{i}.norm1.linear.bias": original_state_dict[
+ f"transformer_blocks.{i}.adaLN_modulation.1.bias"
+ ],
+ }
+ converted_state_dict.update(block_mappings)
+
+ return converted_state_dict
+
+
+def main(args):
+ original_ckpt = load_original_checkpoint(args)
+ original_dtype = next(iter(original_ckpt.values())).dtype
+
+ # Initialize dtype with fp32 as default
+ if args.dtype == "fp16":
+ dtype = torch.float16
+ elif args.dtype == "bf16":
+ dtype = torch.bfloat16
+ elif args.dtype == "fp32":
+ dtype = torch.float32
+ else:
+ raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")
+
+ if dtype != original_dtype:
+ print(
+ f"Converting checkpoint from {original_dtype} to {dtype}. This can lead to unexpected results, proceed with caution."
+ )
+
+ converted_controlnet_state_dict = convert_sd3_controlnet_checkpoint_to_diffusers(original_ckpt)
+
+ controlnet = SD3ControlNetModel(
+ patch_size=2,
+ in_channels=16,
+ num_layers=19,
+ attention_head_dim=64,
+ num_attention_heads=38,
+ joint_attention_dim=None,
+ caption_projection_dim=2048,
+ pooled_projection_dim=2048,
+ out_channels=16,
+ pos_embed_max_size=None,
+ pos_embed_type=None,
+ use_pos_embed=False,
+ force_zeros_for_pooled_projection=False,
+ )
+
+ controlnet.load_state_dict(converted_controlnet_state_dict, strict=True)
+
+ print(f"Saving SD3 ControlNet in Diffusers format in {args.output_path}.")
+ controlnet.to(dtype).save_pretrained(args.output_path)
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py
index 1f9c434b39d0..0a3569efeab0 100644
--- a/scripts/convert_sd3_to_diffusers.py
+++ b/scripts/convert_sd3_to_diffusers.py
@@ -11,7 +11,7 @@
from diffusers.utils.import_utils import is_accelerate_available
-CTX = init_empty_weights if is_accelerate_available else nullcontext
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str)
diff --git a/scripts/convert_versatile_diffusion_to_diffusers.py b/scripts/convert_versatile_diffusion_to_diffusers.py
index 41e2e0191209..ce68bb4c2e8c 100644
--- a/scripts/convert_versatile_diffusion_to_diffusers.py
+++ b/scripts/convert_versatile_diffusion_to_diffusers.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py
new file mode 100644
index 000000000000..0b2fa872487e
--- /dev/null
+++ b/scripts/convert_wan_to_diffusers.py
@@ -0,0 +1,423 @@
+import argparse
+import pathlib
+from typing import Any, Dict
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors.torch import load_file
+from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ UniPCMultistepScheduler,
+ WanImageToVideoPipeline,
+ WanPipeline,
+ WanTransformer3DModel,
+)
+
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+ "time_projection.1": "condition_embedder.time_proj",
+ "head.modulation": "scale_shift_table",
+ "head.head": "proj_out",
+ "modulation": "scale_shift_table",
+ "ffn.0": "ffn.net.0.proj",
+ "ffn.2": "ffn.net.2",
+ # Hack to swap the layer names
+ # The original model calls the norms in following order: norm1, norm3, norm2
+ # We convert it to: norm1, norm2, norm3
+ "norm2": "norm__placeholder",
+ "norm3": "norm2",
+ "norm__placeholder": "norm3",
+ # For the I2V model
+ "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+ "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+ "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+ "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+
+
+def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def load_sharded_safetensors(dir: pathlib.Path):
+ file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
+ state_dict = {}
+ for path in file_paths:
+ state_dict.update(load_file(path))
+ return state_dict
+
+
+def get_transformer_config(model_type: str) -> Dict[str, Any]:
+ if model_type == "Wan-T2V-1.3B":
+ config = {
+ "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 12,
+ "num_layers": 30,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "Wan-T2V-14B":
+ config = {
+ "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "Wan-I2V-14B-480p":
+ config = {
+ "model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
+ "diffusers_config": {
+ "image_dim": 1280,
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "Wan-I2V-14B-720p":
+ config = {
+ "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
+ "diffusers_config": {
+ "image_dim": 1280,
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ return config
+
+
+def convert_transformer(model_type: str):
+ config = get_transformer_config(model_type)
+ diffusers_config = config["diffusers_config"]
+ model_id = config["model_id"]
+ model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
+
+ original_state_dict = load_sharded_safetensors(model_dir)
+
+ with init_empty_weights():
+ transformer = WanTransformer3DModel.from_config(diffusers_config)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+ return transformer
+
+
+def convert_vae():
+ vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
+ old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
+ new_state_dict = {}
+
+ # Create mappings for specific components
+ middle_key_mapping = {
+ # Encoder middle block
+ "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
+ "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
+ "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
+ "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
+ "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
+ "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
+ "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
+ "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
+ "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
+ "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
+ "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
+ "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
+ # Decoder middle block
+ "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
+ "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
+ "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
+ "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
+ "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
+ "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
+ "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
+ "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
+ "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
+ "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
+ "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
+ "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
+ }
+
+ # Create a mapping for attention blocks
+ attention_mapping = {
+ # Encoder middle attention
+ "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
+ "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
+ "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
+ "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
+ "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
+ # Decoder middle attention
+ "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
+ "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
+ "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
+ "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
+ "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
+ }
+
+ # Create a mapping for the head components
+ head_mapping = {
+ # Encoder head
+ "encoder.head.0.gamma": "encoder.norm_out.gamma",
+ "encoder.head.2.bias": "encoder.conv_out.bias",
+ "encoder.head.2.weight": "encoder.conv_out.weight",
+ # Decoder head
+ "decoder.head.0.gamma": "decoder.norm_out.gamma",
+ "decoder.head.2.bias": "decoder.conv_out.bias",
+ "decoder.head.2.weight": "decoder.conv_out.weight",
+ }
+
+ # Create a mapping for the quant components
+ quant_mapping = {
+ "conv1.weight": "quant_conv.weight",
+ "conv1.bias": "quant_conv.bias",
+ "conv2.weight": "post_quant_conv.weight",
+ "conv2.bias": "post_quant_conv.bias",
+ }
+
+ # Process each key in the state dict
+ for key, value in old_state_dict.items():
+ # Handle middle block keys using the mapping
+ if key in middle_key_mapping:
+ new_key = middle_key_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle attention blocks using the mapping
+ elif key in attention_mapping:
+ new_key = attention_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle head keys using the mapping
+ elif key in head_mapping:
+ new_key = head_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle quant keys using the mapping
+ elif key in quant_mapping:
+ new_key = quant_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle encoder conv1
+ elif key == "encoder.conv1.weight":
+ new_state_dict["encoder.conv_in.weight"] = value
+ elif key == "encoder.conv1.bias":
+ new_state_dict["encoder.conv_in.bias"] = value
+ # Handle decoder conv1
+ elif key == "decoder.conv1.weight":
+ new_state_dict["decoder.conv_in.weight"] = value
+ elif key == "decoder.conv1.bias":
+ new_state_dict["decoder.conv_in.bias"] = value
+ # Handle encoder downsamples
+ elif key.startswith("encoder.downsamples."):
+ # Convert to down_blocks
+ new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
+
+ # Convert residual block naming but keep the original structure
+ if ".residual.0.gamma" in new_key:
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+ elif ".residual.2.bias" in new_key:
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+ elif ".residual.2.weight" in new_key:
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+ elif ".residual.3.gamma" in new_key:
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+ elif ".residual.6.bias" in new_key:
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+ elif ".residual.6.weight" in new_key:
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+ elif ".shortcut.bias" in new_key:
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+ elif ".shortcut.weight" in new_key:
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+
+ new_state_dict[new_key] = value
+
+ # Handle decoder upsamples
+ elif key.startswith("decoder.upsamples."):
+ # Convert to up_blocks
+ parts = key.split(".")
+ block_idx = int(parts[2])
+
+ # Group residual blocks
+ if "residual" in key:
+ if block_idx in [0, 1, 2]:
+ new_block_idx = 0
+ resnet_idx = block_idx
+ elif block_idx in [4, 5, 6]:
+ new_block_idx = 1
+ resnet_idx = block_idx - 4
+ elif block_idx in [8, 9, 10]:
+ new_block_idx = 2
+ resnet_idx = block_idx - 8
+ elif block_idx in [12, 13, 14]:
+ new_block_idx = 3
+ resnet_idx = block_idx - 12
+ else:
+ # Keep as is for other blocks
+ new_state_dict[key] = value
+ continue
+
+ # Convert residual block naming
+ if ".residual.0.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
+ elif ".residual.2.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
+ elif ".residual.2.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
+ elif ".residual.3.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
+ elif ".residual.6.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
+ elif ".residual.6.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
+ else:
+ new_key = key
+
+ new_state_dict[new_key] = value
+
+ # Handle shortcut connections
+ elif ".shortcut." in key:
+ if block_idx == 4:
+ new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
+ new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
+
+ new_state_dict[new_key] = value
+
+ # Handle upsamplers
+ elif ".resample." in key or ".time_conv." in key:
+ if block_idx == 3:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
+ elif block_idx == 7:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
+ elif block_idx == 11:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+
+ new_state_dict[new_key] = value
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ new_state_dict[new_key] = value
+ else:
+ # Keep other keys unchanged
+ new_state_dict[key] = value
+
+ with init_empty_weights():
+ vae = AutoencoderKLWan()
+ vae.load_state_dict(new_state_dict, strict=True, assign=True)
+ return vae
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default=None)
+ parser.add_argument("--output_path", type=str, required=True)
+ parser.add_argument("--dtype", default="fp32")
+ return parser.parse_args()
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ transformer = None
+ dtype = DTYPE_MAPPING[args.dtype]
+
+ transformer = convert_transformer(args.model_type).to(dtype=dtype)
+ vae = convert_vae()
+ text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
+ tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+ scheduler = UniPCMultistepScheduler(
+ prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
+ )
+
+ if "I2V" in args.model_type:
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
+ )
+ image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+ pipe = WanImageToVideoPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ image_processor=image_processor,
+ )
+ else:
+ pipe = WanPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ )
+
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
diff --git a/scripts/extract_lora_from_model.py b/scripts/extract_lora_from_model.py
new file mode 100644
index 000000000000..0e01ddea47f9
--- /dev/null
+++ b/scripts/extract_lora_from_model.py
@@ -0,0 +1,151 @@
+"""
+This script demonstrates how to extract a LoRA checkpoint from a fully finetuned model with the CogVideoX model.
+
+To make it work for other models:
+
+* Change the model class. Here we use `CogVideoXTransformer3DModel`. For Flux, it would be `FluxTransformer2DModel`,
+for example. (TODO: more reason to add `AutoModel`).
+* Spply path to the base checkpoint via `base_ckpt_path`.
+* Supply path to the fully fine-tuned checkpoint via `--finetune_ckpt_path`.
+* Change the `--rank` as needed.
+
+Example usage:
+
+```bash
+python extract_lora_from_model.py \
+ --base_ckpt_path=THUDM/CogVideoX-5b \
+ --finetune_ckpt_path=finetrainers/cakeify-v0 \
+ --lora_out_path=cakeify_lora.safetensors
+```
+
+Script is adapted from
+https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py
+"""
+
+import argparse
+
+import torch
+from safetensors.torch import save_file
+from tqdm.auto import tqdm
+
+from diffusers import CogVideoXTransformer3DModel
+
+
+RANK = 64
+CLAMP_QUANTILE = 0.99
+
+
+# Comes from
+# https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py#L9
+def extract_lora(diff, rank):
+ # Important to use CUDA otherwise, very slow!
+ if torch.cuda.is_available():
+ diff = diff.to("cuda")
+
+ is_conv2d = len(diff.shape) == 4
+ kernel_size = None if not is_conv2d else diff.size()[2:4]
+ is_conv2d_3x3 = is_conv2d and kernel_size != (1, 1)
+ out_dim, in_dim = diff.size()[0:2]
+ rank = min(rank, in_dim, out_dim)
+
+ if is_conv2d:
+ if is_conv2d_3x3:
+ diff = diff.flatten(start_dim=1)
+ else:
+ diff = diff.squeeze()
+
+ U, S, Vh = torch.linalg.svd(diff.float())
+ U = U[:, :rank]
+ S = S[:rank]
+ U = U @ torch.diag(S)
+ Vh = Vh[:rank, :]
+
+ dist = torch.cat([U.flatten(), Vh.flatten()])
+ hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+ low_val = -hi_val
+
+ U = U.clamp(low_val, hi_val)
+ Vh = Vh.clamp(low_val, hi_val)
+ if is_conv2d:
+ U = U.reshape(out_dim, rank, 1, 1)
+ Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
+ return (U.cpu(), Vh.cpu())
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--base_ckpt_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub.",
+ )
+ parser.add_argument(
+ "--base_subfolder",
+ default="transformer",
+ type=str,
+ help="subfolder to load the base checkpoint from if any.",
+ )
+ parser.add_argument(
+ "--finetune_ckpt_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Fully fine-tuned checkpoint path. Can be a model ID on the Hub.",
+ )
+ parser.add_argument(
+ "--finetune_subfolder",
+ default=None,
+ type=str,
+ help="subfolder to load the fulle finetuned checkpoint from if any.",
+ )
+ parser.add_argument("--rank", default=64, type=int)
+ parser.add_argument("--lora_out_path", default=None, type=str, required=True)
+ args = parser.parse_args()
+
+ if not args.lora_out_path.endswith(".safetensors"):
+ raise ValueError("`lora_out_path` must end with `.safetensors`.")
+
+ return args
+
+
+@torch.no_grad()
+def main(args):
+ model_finetuned = CogVideoXTransformer3DModel.from_pretrained(
+ args.finetune_ckpt_path, subfolder=args.finetune_subfolder, torch_dtype=torch.bfloat16
+ )
+ state_dict_ft = model_finetuned.state_dict()
+
+ # Change the `subfolder` as needed.
+ base_model = CogVideoXTransformer3DModel.from_pretrained(
+ args.base_ckpt_path, subfolder=args.base_subfolder, torch_dtype=torch.bfloat16
+ )
+ state_dict = base_model.state_dict()
+ output_dict = {}
+
+ for k in tqdm(state_dict, desc="Extracting LoRA..."):
+ original_param = state_dict[k]
+ finetuned_param = state_dict_ft[k]
+ if len(original_param.shape) >= 2:
+ diff = finetuned_param.float() - original_param.float()
+ out = extract_lora(diff, RANK)
+ name = k
+
+ if name.endswith(".weight"):
+ name = name[: -len(".weight")]
+ down_key = "{}.lora_A.weight".format(name)
+ up_key = "{}.lora_B.weight".format(name)
+
+ output_dict[up_key] = out[0].contiguous().to(finetuned_param.dtype)
+ output_dict[down_key] = out[1].contiguous().to(finetuned_param.dtype)
+
+ prefix = "transformer" if "transformer" in base_model.__class__.__name__.lower() else "unet"
+ output_dict = {f"{prefix}.{k}": v for k, v in output_dict.items()}
+ save_file(output_dict, args.lora_out_path)
+ print(f"LoRA saved and it contains {len(output_dict)} keys.")
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/setup.py b/setup.py
index 7a8cc898e005..fdc166a81ecf 100644
--- a/setup.py
+++ b/setup.py
@@ -74,8 +74,9 @@
twine upload dist/* -r pypi
10. Prepare the release notes and publish them on GitHub once everything is looking hunky-dory. You can use the following
- Space to fetch all the commits applicable for the release: https://huggingface.co/spaces/lysandre/github-release. Repo should
- be `huggingface/diffusers`. `tag` should be the previous release tag (v0.26.1, for example), and `branch` should be
+ Space to fetch all the commits applicable for the release: https://huggingface.co/spaces/sayakpaul/auto-release-notes-diffusers.
+ It automatically fetches the correct tag and branch but also provides the option to configure them.
+ `tag` should be the previous release tag (v0.26.1, for example), and `branch` should be
the latest release branch (v0.27.0-release, for example). It denotes all commits that have happened on branch
v0.27.0-release after the tag v0.26.1 was created.
@@ -101,7 +102,7 @@
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
- "huggingface-hub>=0.23.2",
+ "huggingface-hub>=0.27.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
@@ -127,14 +128,20 @@
"GitPython<3.1.19",
"scipy",
"onnx",
+ "optimum_quanto>=0.2.6",
+ "gguf>=0.10.0",
+ "torchao>=0.7.0",
+ "bitsandbytes>=0.43.3",
"regex!=2019.12.17",
"requests",
"tensorboard",
- "torch>=1.4,<2.5.0",
+ "tiktoken>=0.7.0",
+ "torch>=1.4",
"torchvision",
"transformers>=4.41.2",
"urllib3<=2.0.0",
"black",
+ "phonemizer",
]
# this is a lookup table with items like:
@@ -225,11 +232,18 @@ def run(self):
"safetensors",
"sentencepiece",
"scipy",
+ "tiktoken",
"torchvision",
"transformers",
+ "phonemizer",
)
extras["torch"] = deps_list("torch", "accelerate")
+extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
+extras["gguf"] = deps_list("gguf", "accelerate")
+extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
+extras["torchao"] = deps_list("torchao", "accelerate")
+
if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
else:
@@ -254,7 +268,7 @@ def run(self):
setup(
name="diffusers",
- version="0.31.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="0.33.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index a1d126f3823b..9304c34b4e01 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.31.0.dev0"
+__version__ = "0.33.0.dev0"
from typing import TYPE_CHECKING
@@ -6,14 +6,19 @@
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
+ is_accelerate_available,
+ is_bitsandbytes_available,
is_flax_available,
+ is_gguf_available,
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
is_onnx_available,
+ is_optimum_quanto_available,
is_scipy_available,
is_sentencepiece_available,
is_torch_available,
+ is_torchao_available,
is_torchsde_available,
is_transformers_available,
)
@@ -28,10 +33,11 @@
_import_structure = {
"configuration_utils": ["ConfigMixin"],
+ "hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
- "quantizers.quantization_config": ["BitsAndBytesConfig"],
+ "quantizers.quantization_config": [],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -53,6 +59,54 @@
],
}
+try:
+ if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_bitsandbytes_objects
+
+ _import_structure["utils.dummy_bitsandbytes_objects"] = [
+ name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
+
+try:
+ if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_gguf_objects
+
+ _import_structure["utils.dummy_gguf_objects"] = [
+ name for name in dir(dummy_gguf_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
+
+try:
+ if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_torchao_objects
+
+ _import_structure["utils.dummy_torchao_objects"] = [
+ name for name in dir(dummy_torchao_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["quantizers.quantization_config"].append("TorchAoConfig")
+
+try:
+ if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_optimum_quanto_objects
+
+ _import_structure["utils.dummy_optimum_quanto_objects"] = [
+ name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["quantizers.quantization_config"].append("QuantoConfig")
+
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -75,36 +129,65 @@
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
+ _import_structure["hooks"].extend(
+ [
+ "FasterCacheConfig",
+ "HookRegistry",
+ "PyramidAttentionBroadcastConfig",
+ "apply_faster_cache",
+ "apply_pyramid_attention_broadcast",
+ ]
+ )
_import_structure["models"].extend(
[
+ "AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
"AuraFlowTransformer2DModel",
+ "AutoencoderDC",
"AutoencoderKL",
+ "AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
+ "AutoencoderKLHunyuanVideo",
+ "AutoencoderKLLTXVideo",
+ "AutoencoderKLMagvit",
+ "AutoencoderKLMochi",
"AutoencoderKLTemporalDecoder",
+ "AutoencoderKLWan",
"AutoencoderOobleck",
"AutoencoderTiny",
+ "CacheMixin",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
+ "CogView4Transformer2DModel",
+ "ConsisIDTransformer3DModel",
"ConsistencyDecoderVAE",
"ControlNetModel",
+ "ControlNetUnionModel",
"ControlNetXSAdapter",
"DiTTransformer2DModel",
+ "EasyAnimateTransformer3DModel",
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
+ "HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"LatteTransformer3DModel",
+ "LTXVideoTransformer3DModel",
+ "Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
+ "MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
+ "MultiControlNetModel",
+ "OmniGenTransformer2DModel",
"PixArtTransformer2DModel",
"PriorTransformer",
+ "SanaTransformer2DModel",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
@@ -123,6 +206,7 @@
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
+ "WanTransformer3DModel",
]
)
_import_structure["optimization"] = [
@@ -189,6 +273,7 @@
"RePaintScheduler",
"SASolverScheduler",
"SchedulerMixin",
+ "SCMScheduler",
"ScoreSdeVeScheduler",
"TCDScheduler",
"UnCLIPScheduler",
@@ -237,6 +322,7 @@
else:
_import_structure["pipelines"].extend(
[
+ "AllegroPipeline",
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
"AmusedImg2ImgPipeline",
@@ -262,16 +348,30 @@
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
+ "CogView4ControlPipeline",
+ "CogView4Pipeline",
+ "ConsisIDPipeline",
"CycleDiffusionPipeline",
+ "EasyAnimateControlPipeline",
+ "EasyAnimateInpaintPipeline",
+ "EasyAnimatePipeline",
+ "FluxControlImg2ImgPipeline",
+ "FluxControlInpaintPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
"FluxControlNetPipeline",
+ "FluxControlPipeline",
+ "FluxFillPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxPipeline",
+ "FluxPriorReduxPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
+ "HunyuanSkyreelsImageToVideoPipeline",
+ "HunyuanVideoImageToVideoPipeline",
+ "HunyuanVideoPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
@@ -305,15 +405,28 @@
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
+ "LTXConditionPipeline",
+ "LTXImageToVideoPipeline",
+ "LTXPipeline",
+ "Lumina2Pipeline",
+ "Lumina2Text2ImgPipeline",
+ "LuminaPipeline",
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
+ "MarigoldIntrinsicsPipeline",
"MarigoldNormalsPipeline",
+ "MochiPipeline",
"MusicLDMPipeline",
+ "OmniGenPipeline",
"PaintByExamplePipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
+ "ReduxImageEncoder",
+ "SanaPAGPipeline",
+ "SanaPipeline",
+ "SanaSprintPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -326,6 +439,8 @@
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
+ "StableDiffusion3PAGImg2ImgPipeline",
+ "StableDiffusion3PAGImg2ImgPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline",
@@ -349,6 +464,7 @@
"StableDiffusionLDM3DPipeline",
"StableDiffusionModelEditingPipeline",
"StableDiffusionPAGImg2ImgPipeline",
+ "StableDiffusionPAGInpaintPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionPanoramaPipeline",
"StableDiffusionParadigmsPipeline",
@@ -363,6 +479,9 @@
"StableDiffusionXLControlNetPAGImg2ImgPipeline",
"StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLControlNetPipeline",
+ "StableDiffusionXLControlNetUnionImg2ImgPipeline",
+ "StableDiffusionXLControlNetUnionInpaintPipeline",
+ "StableDiffusionXLControlNetUnionPipeline",
"StableDiffusionXLControlNetXSPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
@@ -388,6 +507,9 @@
"VersatileDiffusionTextToImagePipeline",
"VideoToVideoSDPipeline",
"VQDiffusionPipeline",
+ "WanImageToVideoPipeline",
+ "WanPipeline",
+ "WanVideoToVideoPipeline",
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
@@ -481,7 +603,7 @@
else:
- _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
+ _import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
@@ -539,7 +661,38 @@
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
- from .quantizers.quantization_config import BitsAndBytesConfig
+
+ try:
+ if not is_bitsandbytes_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_bitsandbytes_objects import *
+ else:
+ from .quantizers.quantization_config import BitsAndBytesConfig
+
+ try:
+ if not is_gguf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_gguf_objects import *
+ else:
+ from .quantizers.quantization_config import GGUFQuantizationConfig
+
+ try:
+ if not is_torchao_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_torchao_objects import *
+ else:
+ from .quantizers.quantization_config import TorchAoConfig
+
+ try:
+ if not is_optimum_quanto_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_optimum_quanto_objects import *
+ else:
+ from .quantizers.quantization_config import QuantoConfig
try:
if not is_onnx_available():
@@ -555,35 +708,62 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
+ from .hooks import (
+ FasterCacheConfig,
+ HookRegistry,
+ PyramidAttentionBroadcastConfig,
+ apply_faster_cache,
+ apply_pyramid_attention_broadcast,
+ )
from .models import (
+ AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
AuraFlowTransformer2DModel,
+ AutoencoderDC,
AutoencoderKL,
+ AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
+ AutoencoderKLHunyuanVideo,
+ AutoencoderKLLTXVideo,
+ AutoencoderKLMagvit,
+ AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
+ AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderTiny,
+ CacheMixin,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
+ CogView4Transformer2DModel,
+ ConsisIDTransformer3DModel,
ConsistencyDecoderVAE,
ControlNetModel,
+ ControlNetUnionModel,
ControlNetXSAdapter,
DiTTransformer2DModel,
+ EasyAnimateTransformer3DModel,
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
+ HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
Kandinsky3UNet,
LatteTransformer3DModel,
+ LTXVideoTransformer3DModel,
+ Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
+ MochiTransformer3DModel,
ModelMixin,
MotionAdapter,
MultiAdapter,
+ MultiControlNetModel,
+ OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
+ SanaTransformer2DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
@@ -601,6 +781,7 @@
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
+ WanTransformer3DModel,
)
from .optimization import (
get_constant_schedule,
@@ -666,6 +847,7 @@
RePaintScheduler,
SASolverScheduler,
SchedulerMixin,
+ SCMScheduler,
ScoreSdeVeScheduler,
TCDScheduler,
UnCLIPScheduler,
@@ -697,6 +879,7 @@
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipelines import (
+ AllegroPipeline,
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
AmusedImg2ImgPipeline,
@@ -720,16 +903,30 @@
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
+ CogView4ControlPipeline,
+ CogView4Pipeline,
+ ConsisIDPipeline,
CycleDiffusionPipeline,
+ EasyAnimateControlPipeline,
+ EasyAnimateInpaintPipeline,
+ EasyAnimatePipeline,
+ FluxControlImg2ImgPipeline,
+ FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
+ FluxControlPipeline,
+ FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
+ FluxPriorReduxPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
+ HunyuanSkyreelsImageToVideoPipeline,
+ HunyuanVideoImageToVideoPipeline,
+ HunyuanVideoPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
@@ -763,15 +960,28 @@
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
+ LTXConditionPipeline,
+ LTXImageToVideoPipeline,
+ LTXPipeline,
+ Lumina2Pipeline,
+ Lumina2Text2ImgPipeline,
+ LuminaPipeline,
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
+ MarigoldIntrinsicsPipeline,
MarigoldNormalsPipeline,
+ MochiPipeline,
MusicLDMPipeline,
+ OmniGenPipeline,
PaintByExamplePipeline,
PIAPipeline,
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
+ ReduxImageEncoder,
+ SanaPAGPipeline,
+ SanaPipeline,
+ SanaSprintPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
@@ -780,9 +990,11 @@
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
+ StableDiffusion3ControlNetInpaintingPipeline,
StableDiffusion3ControlNetPipeline,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
+ StableDiffusion3PAGImg2ImgPipeline,
StableDiffusion3PAGPipeline,
StableDiffusion3Pipeline,
StableDiffusionAdapterPipeline,
@@ -806,6 +1018,7 @@
StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline,
StableDiffusionPAGImg2ImgPipeline,
+ StableDiffusionPAGInpaintPipeline,
StableDiffusionPAGPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionParadigmsPipeline,
@@ -820,6 +1033,9 @@
StableDiffusionXLControlNetPAGImg2ImgPipeline,
StableDiffusionXLControlNetPAGPipeline,
StableDiffusionXLControlNetPipeline,
+ StableDiffusionXLControlNetUnionImg2ImgPipeline,
+ StableDiffusionXLControlNetUnionInpaintPipeline,
+ StableDiffusionXLControlNetUnionPipeline,
StableDiffusionXLControlNetXSPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
@@ -845,6 +1061,9 @@
VersatileDiffusionTextToImagePipeline,
VideoToVideoSDPipeline,
VQDiffusionPipeline,
+ WanImageToVideoPipeline,
+ WanPipeline,
+ WanVideoToVideoPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
@@ -902,7 +1121,7 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_objects import * # noqa F403
else:
- from .models.controlnet_flax import FlaxControlNetModel
+ from .models.controlnets.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py
index 38542407e31f..4b8b15368c47 100644
--- a/src/diffusers/callbacks.py
+++ b/src/diffusers/callbacks.py
@@ -97,13 +97,17 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
class SDXLCFGCutoffCallback(PipelineCallback):
"""
- Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
- `cutoff_step_index`), this callback will disable the CFG.
+ Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by
+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""
- tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
+ tensor_inputs = [
+ "prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ ]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
@@ -129,6 +133,55 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+
+ return callback_kwargs
+
+
+class SDXLControlnetCFGCutoffCallback(PipelineCallback):
+ """
+ Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by
+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
+
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+ """
+
+ tensor_inputs = [
+ "prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "image",
+ ]
+
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+ cutoff_step_ratio = self.config.cutoff_step_ratio
+ cutoff_step_index = self.config.cutoff_step_index
+
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+ cutoff_step = (
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+ )
+
+ if step_index == cutoff_step:
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
+
+ add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
+ add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens
+
+ add_time_ids = callback_kwargs[self.tensor_inputs[2]]
+ add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector
+
+ # For Controlnet
+ image = callback_kwargs[self.tensor_inputs[3]]
+ image = image[-1:]
+
+ pipeline._guidance_scale = 0.0
+
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+ callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+ callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+ callback_kwargs[self.tensor_inputs[3]] = image
+
return callback_kwargs
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 11d45dc64d97..f9b652bbc021 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,10 +24,10 @@
import re
from collections import OrderedDict
from pathlib import Path
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
-from huggingface_hub import create_repo, hf_hub_download
+from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
@@ -35,6 +35,7 @@
validate_hf_hub_args,
)
from requests import HTTPError
+from typing_extensions import Self
from . import __version__
from .utils import (
@@ -170,7 +171,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
- private = kwargs.pop("private", False)
+ private = kwargs.pop("private", None)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -185,7 +186,9 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
)
@classmethod
- def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
+ def from_config(
+ cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs
+ ) -> Union[Self, Tuple[Self, Dict[str, Any]]]:
r"""
Instantiate a Python class from a config dictionary.
@@ -347,6 +350,7 @@ def load_config(
_ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
+ dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)
@@ -358,8 +362,15 @@ def load_config(
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
)
-
- if os.path.isfile(pretrained_model_name_or_path):
+ # Custom path for now
+ if dduf_entries:
+ if subfolder is not None:
+ raise ValueError(
+ "DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
+ "Please check the DDUF structure"
+ )
+ config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries)
+ elif os.path.isfile(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if subfolder is not None and os.path.isfile(
@@ -426,10 +437,8 @@ def load_config(
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {cls.config_name} file"
)
-
try:
- # Load config dict
- config_dict = cls._dict_from_json_file(config_file)
+ config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries)
commit_hash = extract_commit_hash(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
@@ -552,9 +561,14 @@ def extract_init_dict(cls, config_dict, **kwargs):
return init_dict, unused_kwargs, hidden_config_dict
@classmethod
- def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
- with open(json_file, "r", encoding="utf-8") as reader:
- text = reader.read()
+ def _dict_from_json_file(
+ cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
+ ):
+ if dduf_entries:
+ text = dduf_entries[json_file].read_text()
+ else:
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
return json.loads(text)
def __repr__(self):
@@ -616,6 +630,20 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())
+ @classmethod
+ def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
+ # paths inside a DDUF file must always be "/"
+ config_file = (
+ cls.config_name
+ if pretrained_model_name_or_path == ""
+ else "/".join([pretrained_model_name_or_path, cls.config_name])
+ )
+ if config_file not in dduf_entries:
+ raise ValueError(
+ f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"
+ )
+ return config_file
+
def register_to_config(init):
r"""
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 0e421b71e48d..8ec95ed6fc8d 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -9,7 +9,7 @@
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
- "huggingface-hub": "huggingface-hub>=0.23.2",
+ "huggingface-hub": "huggingface-hub>=0.27.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
@@ -35,12 +35,18 @@
"GitPython": "GitPython<3.1.19",
"scipy": "scipy",
"onnx": "onnx",
+ "optimum_quanto": "optimum_quanto>=0.2.6",
+ "gguf": "gguf>=0.10.0",
+ "torchao": "torchao>=0.7.0",
+ "bitsandbytes": "bitsandbytes>=0.43.3",
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
- "torch": "torch>=1.4,<2.5.0",
+ "tiktoken": "tiktoken>=0.7.0",
+ "torch": "torch>=1.4",
"torchvision": "torchvision",
"transformers": "transformers>=4.41.2",
"urllib3": "urllib3<=2.0.0",
"black": "black",
+ "phonemizer": "phonemizer",
}
diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py
new file mode 100644
index 000000000000..764ceb25b465
--- /dev/null
+++ b/src/diffusers/hooks/__init__.py
@@ -0,0 +1,9 @@
+from ..utils import is_torch_available
+
+
+if is_torch_available():
+ from .faster_cache import FasterCacheConfig, apply_faster_cache
+ from .group_offloading import apply_group_offloading
+ from .hooks import HookRegistry, ModelHook
+ from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
+ from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py
new file mode 100644
index 000000000000..634635346474
--- /dev/null
+++ b/src/diffusers/hooks/faster_cache.py
@@ -0,0 +1,653 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from dataclasses import dataclass
+from typing import Any, Callable, List, Optional, Tuple
+
+import torch
+
+from ..models.attention_processor import Attention, MochiAttention
+from ..models.modeling_outputs import Transformer2DModelOutput
+from ..utils import logging
+from .hooks import HookRegistry, ModelHook
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
+_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
+_ATTENTION_CLASSES = (Attention, MochiAttention)
+_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
+ "^blocks.*attn",
+ "^transformer_blocks.*attn",
+ "^single_transformer_blocks.*attn",
+)
+_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
+_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
+_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
+ "hidden_states",
+ "encoder_hidden_states",
+ "timestep",
+ "attention_mask",
+ "encoder_attention_mask",
+)
+
+
+@dataclass
+class FasterCacheConfig:
+ r"""
+ Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).
+
+ Attributes:
+ spatial_attention_block_skip_range (`int`, defaults to `2`):
+ Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
+ be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
+ states again.
+ temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
+ Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
+ be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
+ states again.
+ spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
+ The timestep range within which the spatial attention computation can be skipped without a significant loss
+ in quality. This is to be determined by the user based on the underlying model. The first value in the
+ tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
+ denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
+ timestep 0). For the default values, this would mean that the spatial attention computation skipping will
+ be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
+ process.
+ temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
+ The timestep range within which the temporal attention computation can be skipped without a significant
+ loss in quality. This is to be determined by the user based on the underlying model. The first value in the
+ tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
+ denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
+ timestep 0).
+ low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
+ The timestep range within which the low frequency weight scaling update is applied. The first value in the
+ tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
+ function for the update is called only within this range.
+ high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
+ The timestep range within which the high frequency weight scaling update is applied. The first value in the
+ tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
+ function for the update is called only within this range.
+ alpha_low_frequency (`float`, defaults to `1.1`):
+ The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
+ the conditional branch outputs.
+ alpha_high_frequency (`float`, defaults to `1.1`):
+ The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
+ from the conditional branch outputs.
+ unconditional_batch_skip_range (`int`, defaults to `5`):
+ Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
+ computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
+ computing the new unconditional branch states again.
+ unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
+ The timestep range within which the unconditional branch computation can be skipped without a significant
+ loss in quality. This is to be determined by the user based on the underlying model. The first value in the
+ tuple is the lower bound and the second value is the upper bound.
+ spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
+ The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
+ of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
+ partial layer names, or regex patterns. Matching will always be done using a regex match.
+ temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
+ The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
+ of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
+ partial layer names, or regex patterns. Matching will always be done using a regex match.
+ attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
+ The callback function to determine the weight to scale the attention outputs by. This function should take
+ the attention module as input and return a float value. This is used to approximate the unconditional
+ branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
+ Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
+ progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
+ the number of inference steps and underlying model behaviour as denoising progresses.
+ low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
+ The callback function to determine the weight to scale the low frequency updates by. If not provided, the
+ default weight is 1.1 for timesteps within the range specified (as described in the paper).
+ high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
+ The callback function to determine the weight to scale the high frequency updates by. If not provided, the
+ default weight is 1.1 for timesteps within the range specified (as described in the paper).
+ tensor_format (`str`, defaults to `"BCFHW"`):
+ The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
+ used to split individual latent frames in order for low and high frequency components to be computed.
+ is_guidance_distilled (`bool`, defaults to `False`):
+ Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
+ applied at the denoiser-level to skip the unconditional branch computation (as there is none).
+ _unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
+ The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
+ conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
+ split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
+ names that contain the batchwise-concatenated unconditional and conditional inputs.
+ """
+
+ # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
+ # after some testing. We default to 2 if these parameters are not provided.
+ spatial_attention_block_skip_range: int = 2
+ temporal_attention_block_skip_range: Optional[int] = None
+
+ spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
+ temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
+
+ # Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
+ low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
+ high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
+
+ # ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
+ alpha_low_frequency: float = 1.1
+ alpha_high_frequency: float = 1.1
+
+ # n as described in CFG-Cache explanation in the paper - dependant on the model
+ unconditional_batch_skip_range: int = 5
+ unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
+
+ spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
+ temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
+
+ attention_weight_callback: Callable[[torch.nn.Module], float] = None
+ low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
+ high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
+
+ tensor_format: str = "BCFHW"
+ is_guidance_distilled: bool = False
+
+ current_timestep_callback: Callable[[], int] = None
+
+ _unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
+
+ def __repr__(self) -> str:
+ return (
+ f"FasterCacheConfig(\n"
+ f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
+ f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
+ f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
+ f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
+ f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
+ f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
+ f" alpha_low_frequency={self.alpha_low_frequency},\n"
+ f" alpha_high_frequency={self.alpha_high_frequency},\n"
+ f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
+ f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
+ f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
+ f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
+ f" tensor_format={self.tensor_format},\n"
+ f")"
+ )
+
+
+class FasterCacheDenoiserState:
+ r"""
+ State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module.
+ """
+
+ def __init__(self) -> None:
+ self.iteration: int = 0
+ self.low_frequency_delta: torch.Tensor = None
+ self.high_frequency_delta: torch.Tensor = None
+
+ def reset(self):
+ self.iteration = 0
+ self.low_frequency_delta = None
+ self.high_frequency_delta = None
+
+
+class FasterCacheBlockState:
+ r"""
+ State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is
+ applied to will have an instance of this state.
+ """
+
+ def __init__(self) -> None:
+ self.iteration: int = 0
+ self.batch_size: int = None
+ self.cache: Tuple[torch.Tensor, torch.Tensor] = None
+
+ def reset(self):
+ self.iteration = 0
+ self.batch_size = None
+ self.cache = None
+
+
+class FasterCacheDenoiserHook(ModelHook):
+ _is_stateful = True
+
+ def __init__(
+ self,
+ unconditional_batch_skip_range: int,
+ unconditional_batch_timestep_skip_range: Tuple[int, int],
+ tensor_format: str,
+ is_guidance_distilled: bool,
+ uncond_cond_input_kwargs_identifiers: List[str],
+ current_timestep_callback: Callable[[], int],
+ low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
+ high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
+ ) -> None:
+ super().__init__()
+
+ self.unconditional_batch_skip_range = unconditional_batch_skip_range
+ self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
+ # We can't easily detect what args are to be split in unconditional and conditional branches. We
+ # can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
+ # If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
+ # contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
+ self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
+ self.tensor_format = tensor_format
+ self.is_guidance_distilled = is_guidance_distilled
+
+ self.current_timestep_callback = current_timestep_callback
+ self.low_frequency_weight_callback = low_frequency_weight_callback
+ self.high_frequency_weight_callback = high_frequency_weight_callback
+
+ def initialize_hook(self, module):
+ self.state = FasterCacheDenoiserState()
+ return module
+
+ @staticmethod
+ def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
+ # followed by conditional inputs.
+ _, cond = input.chunk(2, dim=0)
+ return cond
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
+ # Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
+ # requirements for skipping the unconditional branch are met as described in the paper.
+ # We skip the unconditional branch only if the following conditions are met:
+ # 1. We have completed at least one iteration of the denoiser
+ # 2. The current timestep is within the range specified by the user. This is the optimal timestep range
+ # where approximating the unconditional branch from the computation of the conditional branch is possible
+ # without a significant loss in quality.
+ # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
+ # we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
+ is_within_timestep_range = (
+ self.unconditional_batch_timestep_skip_range[0]
+ < self.current_timestep_callback()
+ < self.unconditional_batch_timestep_skip_range[1]
+ )
+ should_skip_uncond = (
+ self.state.iteration > 0
+ and is_within_timestep_range
+ and self.state.iteration % self.unconditional_batch_skip_range != 0
+ and not self.is_guidance_distilled
+ )
+
+ if should_skip_uncond:
+ is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
+ if is_any_kwarg_uncond:
+ logger.debug("FasterCache - Skipping unconditional branch computation")
+ args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
+ kwargs = {
+ k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
+ for k, v in kwargs.items()
+ }
+
+ output = self.fn_ref.original_forward(*args, **kwargs)
+
+ if self.is_guidance_distilled:
+ self.state.iteration += 1
+ return output
+
+ if torch.is_tensor(output):
+ hidden_states = output
+ elif isinstance(output, (tuple, Transformer2DModelOutput)):
+ hidden_states = output[0]
+
+ batch_size = hidden_states.size(0)
+
+ if should_skip_uncond:
+ self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
+ module
+ )
+ self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
+ module
+ )
+
+ if self.tensor_format == "BCFHW":
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+ if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
+ hidden_states = hidden_states.flatten(0, 1)
+
+ low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())
+
+ # Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
+ low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
+ high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
+ uncond_freq = low_freq_uncond + high_freq_uncond
+
+ uncond_states = torch.fft.ifftshift(uncond_freq)
+ uncond_states = torch.fft.ifft2(uncond_states).real
+
+ if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
+ uncond_states = uncond_states.unflatten(0, (batch_size, -1))
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1))
+ if self.tensor_format == "BCFHW":
+ uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+
+ # Concatenate the approximated unconditional and predicted conditional branches
+ uncond_states = uncond_states.to(hidden_states.dtype)
+ hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
+ else:
+ uncond_states, cond_states = hidden_states.chunk(2, dim=0)
+ if self.tensor_format == "BCFHW":
+ uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
+ cond_states = cond_states.permute(0, 2, 1, 3, 4)
+ if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
+ uncond_states = uncond_states.flatten(0, 1)
+ cond_states = cond_states.flatten(0, 1)
+
+ low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
+ low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
+ self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
+ self.state.high_frequency_delta = high_freq_uncond - high_freq_cond
+
+ self.state.iteration += 1
+ if torch.is_tensor(output):
+ output = hidden_states
+ elif isinstance(output, tuple):
+ output = (hidden_states, *output[1:])
+ else:
+ output.sample = hidden_states
+
+ return output
+
+ def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
+ self.state.reset()
+ return module
+
+
+class FasterCacheBlockHook(ModelHook):
+ _is_stateful = True
+
+ def __init__(
+ self,
+ block_skip_range: int,
+ timestep_skip_range: Tuple[int, int],
+ is_guidance_distilled: bool,
+ weight_callback: Callable[[torch.nn.Module], float],
+ current_timestep_callback: Callable[[], int],
+ ) -> None:
+ super().__init__()
+
+ self.block_skip_range = block_skip_range
+ self.timestep_skip_range = timestep_skip_range
+ self.is_guidance_distilled = is_guidance_distilled
+
+ self.weight_callback = weight_callback
+ self.current_timestep_callback = current_timestep_callback
+
+ def initialize_hook(self, module):
+ self.state = FasterCacheBlockState()
+ return module
+
+ def _compute_approximated_attention_output(
+ self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
+ ) -> torch.Tensor:
+ if t_2_output.size(0) != batch_size:
+ # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
+ # take the conditional branch outputs.
+ assert t_2_output.size(0) == 2 * batch_size
+ t_2_output = t_2_output[batch_size:]
+ if t_output.size(0) != batch_size:
+ # The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
+ # take the conditional branch outputs.
+ assert t_output.size(0) == 2 * batch_size
+ t_output = t_output[batch_size:]
+ return t_output + (t_output - t_2_output) * weight
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
+ batch_size = [
+ *[arg.size(0) for arg in args if torch.is_tensor(arg)],
+ *[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
+ ][0]
+ if self.state.batch_size is None:
+ # Will be updated on first forward pass through the denoiser
+ self.state.batch_size = batch_size
+
+ # If we have to skip due to the skip conditions, then let's skip as expected.
+ # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
+ # is because the expected output shapes of attention layer will not match if we only return values from
+ # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
+ # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
+ # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
+ is_within_timestep_range = (
+ self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
+ )
+ if not is_within_timestep_range:
+ should_skip_attention = False
+ else:
+ should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
+ should_skip_attention = not should_compute_attention
+ if should_skip_attention:
+ should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size
+
+ if should_skip_attention:
+ logger.debug("FasterCache - Skipping attention and using approximation")
+ if torch.is_tensor(self.state.cache[-1]):
+ t_2_output, t_output = self.state.cache
+ weight = self.weight_callback(module)
+ output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
+ else:
+ # The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
+ # Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
+ # In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
+ # a forward pass of the block. We need to compute the approximated output for each of these tensors.
+ # The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
+ # allows us to compute the approximated attention output for each tensor in the cache.
+ output = ()
+ for t_2_output, t_output in zip(*self.state.cache):
+ result = self._compute_approximated_attention_output(
+ t_2_output, t_output, self.weight_callback(module), batch_size
+ )
+ output += (result,)
+ else:
+ logger.debug("FasterCache - Computing attention")
+ output = self.fn_ref.original_forward(*args, **kwargs)
+
+ # Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
+ # a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
+ # both cases.
+ if torch.is_tensor(output):
+ cache_output = output
+ if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
+ # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
+ # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
+ cache_output = cache_output.chunk(2, dim=0)[1]
+ else:
+ # Cache all return values and perform the same operation as above
+ cache_output = ()
+ for out in output:
+ if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
+ out = out.chunk(2, dim=0)[1]
+ cache_output += (out,)
+
+ if self.state.cache is None:
+ self.state.cache = [cache_output, cache_output]
+ else:
+ self.state.cache = [self.state.cache[-1], cache_output]
+
+ self.state.iteration += 1
+ return output
+
+ def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
+ self.state.reset()
+ return module
+
+
+def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
+ r"""
+ Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
+
+ Args:
+ pipeline (`DiffusionPipeline`):
+ The diffusion pipeline to apply FasterCache to.
+ config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
+ The configuration to use for FasterCache.
+
+ Example:
+ ```python
+ >>> import torch
+ >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache
+
+ >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> config = FasterCacheConfig(
+ ... spatial_attention_block_skip_range=2,
+ ... spatial_attention_timestep_skip_range=(-1, 681),
+ ... low_frequency_weight_update_timestep_range=(99, 641),
+ ... high_frequency_weight_update_timestep_range=(-1, 301),
+ ... spatial_attention_block_identifiers=["transformer_blocks"],
+ ... attention_weight_callback=lambda _: 0.3,
+ ... tensor_format="BFCHW",
+ ... )
+ >>> apply_faster_cache(pipe.transformer, config)
+ ```
+ """
+
+ logger.warning(
+ "FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
+ "The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
+ "https://github.com/huggingface/diffusers/issues."
+ )
+
+ if config.attention_weight_callback is None:
+ # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
+ # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
+ # this depends from model-to-model. It is required by the user to provide a weight callback if they want to
+ # use a different weight function. Defaulting to 0.5 works well in practice for most cases.
+ logger.warning(
+ "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
+ )
+ config.attention_weight_callback = lambda _: 0.5
+
+ if config.low_frequency_weight_callback is None:
+ logger.debug(
+ "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
+ )
+
+ def low_frequency_weight_callback(module: torch.nn.Module) -> float:
+ is_within_range = (
+ config.low_frequency_weight_update_timestep_range[0]
+ < config.current_timestep_callback()
+ < config.low_frequency_weight_update_timestep_range[1]
+ )
+ return config.alpha_low_frequency if is_within_range else 1.0
+
+ config.low_frequency_weight_callback = low_frequency_weight_callback
+
+ if config.high_frequency_weight_callback is None:
+ logger.debug(
+ "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
+ )
+
+ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
+ is_within_range = (
+ config.high_frequency_weight_update_timestep_range[0]
+ < config.current_timestep_callback()
+ < config.high_frequency_weight_update_timestep_range[1]
+ )
+ return config.alpha_high_frequency if is_within_range else 1.0
+
+ config.high_frequency_weight_callback = high_frequency_weight_callback
+
+ supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video
+ if config.tensor_format not in supported_tensor_formats:
+ raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
+
+ _apply_faster_cache_on_denoiser(module, config)
+
+ for name, submodule in module.named_modules():
+ if not isinstance(submodule, _ATTENTION_CLASSES):
+ continue
+ if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
+ _apply_faster_cache_on_attention_class(name, submodule, config)
+
+
+def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
+ hook = FasterCacheDenoiserHook(
+ config.unconditional_batch_skip_range,
+ config.unconditional_batch_timestep_skip_range,
+ config.tensor_format,
+ config.is_guidance_distilled,
+ config._unconditional_conditional_input_kwargs_identifiers,
+ config.current_timestep_callback,
+ config.low_frequency_weight_callback,
+ config.high_frequency_weight_callback,
+ )
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+ registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
+
+
+def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
+ is_spatial_self_attention = (
+ any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
+ and config.spatial_attention_block_skip_range is not None
+ and not getattr(module, "is_cross_attention", False)
+ )
+ is_temporal_self_attention = (
+ any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
+ and config.temporal_attention_block_skip_range is not None
+ and not module.is_cross_attention
+ )
+
+ block_skip_range, timestep_skip_range, block_type = None, None, None
+ if is_spatial_self_attention:
+ block_skip_range = config.spatial_attention_block_skip_range
+ timestep_skip_range = config.spatial_attention_timestep_skip_range
+ block_type = "spatial"
+ elif is_temporal_self_attention:
+ block_skip_range = config.temporal_attention_block_skip_range
+ timestep_skip_range = config.temporal_attention_timestep_skip_range
+ block_type = "temporal"
+
+ if block_skip_range is None or timestep_skip_range is None:
+ logger.debug(
+ f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
+ f"not match any of the required criteria for spatial or temporal attention layers. Note, "
+ f"however, that this layer may still be valid for applying PAB. Please specify the correct "
+ f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` "
+ f"function to apply FasterCache to this layer."
+ )
+ return
+
+ logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
+ hook = FasterCacheBlockHook(
+ block_skip_range,
+ timestep_skip_range,
+ config.is_guidance_distilled,
+ config.attention_weight_callback,
+ config.current_timestep_callback,
+ )
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+ registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
+
+
+# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
+@torch.no_grad()
+def _split_low_high_freq(x):
+ fft = torch.fft.fft2(x)
+ fft_shifted = torch.fft.fftshift(fft)
+ height, width = x.shape[-2:]
+ radius = min(height, width) // 5
+
+ y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
+ center_x, center_y = width // 2, height // 2
+ mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2
+
+ low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
+ high_freq_mask = ~low_freq_mask
+
+ low_freq_fft = fft_shifted * low_freq_mask
+ high_freq_fft = fft_shifted * high_freq_mask
+
+ return low_freq_fft, high_freq_fft
diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
new file mode 100644
index 000000000000..4c1d354a0f59
--- /dev/null
+++ b/src/diffusers/hooks/group_offloading.py
@@ -0,0 +1,735 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import contextmanager, nullcontext
+from typing import Dict, List, Optional, Set, Tuple
+
+import torch
+
+from ..utils import get_logger, is_accelerate_available
+from .hooks import HookRegistry, ModelHook
+
+
+if is_accelerate_available():
+ from accelerate.hooks import AlignDevicesHook, CpuOffload
+ from accelerate.utils import send_to_device
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+# fmt: off
+_GROUP_OFFLOADING = "group_offloading"
+_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
+_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
+
+_SUPPORTED_PYTORCH_LAYERS = (
+ torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
+ torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
+ torch.nn.Linear,
+ # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
+ # because of double invocation of the same norm layer in CogVideoXLayerNorm
+)
+# fmt: on
+
+
+class ModuleGroup:
+ def __init__(
+ self,
+ modules: List[torch.nn.Module],
+ offload_device: torch.device,
+ onload_device: torch.device,
+ offload_leader: torch.nn.Module,
+ onload_leader: Optional[torch.nn.Module] = None,
+ parameters: Optional[List[torch.nn.Parameter]] = None,
+ buffers: Optional[List[torch.Tensor]] = None,
+ non_blocking: bool = False,
+ stream: Optional[torch.cuda.Stream] = None,
+ low_cpu_mem_usage=False,
+ onload_self: bool = True,
+ ) -> None:
+ self.modules = modules
+ self.offload_device = offload_device
+ self.onload_device = onload_device
+ self.offload_leader = offload_leader
+ self.onload_leader = onload_leader
+ self.parameters = parameters or []
+ self.buffers = buffers or []
+ self.non_blocking = non_blocking or stream is not None
+ self.stream = stream
+ self.onload_self = onload_self
+ self.low_cpu_mem_usage = low_cpu_mem_usage
+
+ self.cpu_param_dict = self._init_cpu_param_dict()
+
+ def _init_cpu_param_dict(self):
+ cpu_param_dict = {}
+ if self.stream is None:
+ return cpu_param_dict
+
+ for module in self.modules:
+ for param in module.parameters():
+ cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+ for buffer in module.buffers():
+ cpu_param_dict[buffer] = (
+ buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+ )
+
+ for param in self.parameters:
+ cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+
+ for buffer in self.buffers:
+ cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+
+ return cpu_param_dict
+
+ @contextmanager
+ def _pinned_memory_tensors(self):
+ pinned_dict = {}
+ try:
+ for param, tensor in self.cpu_param_dict.items():
+ if not tensor.is_pinned():
+ pinned_dict[param] = tensor.pin_memory()
+ else:
+ pinned_dict[param] = tensor
+
+ yield pinned_dict
+
+ finally:
+ pinned_dict = None
+
+ def onload_(self):
+ r"""Onloads the group of modules to the onload_device."""
+ context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
+ if self.stream is not None:
+ # Wait for previous Host->Device transfer to complete
+ self.stream.synchronize()
+
+ with context:
+ if self.stream is not None:
+ with self._pinned_memory_tensors() as pinned_memory:
+ for group_module in self.modules:
+ for param in group_module.parameters():
+ param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+ for buffer in group_module.buffers():
+ buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+ for param in self.parameters:
+ param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+ for buffer in self.buffers:
+ buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+ else:
+ for group_module in self.modules:
+ for param in group_module.parameters():
+ param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+ for buffer in group_module.buffers():
+ buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+
+ for param in self.parameters:
+ param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+
+ for buffer in self.buffers:
+ buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+
+ def offload_(self):
+ r"""Offloads the group of modules to the offload_device."""
+ if self.stream is not None:
+ torch.cuda.current_stream().synchronize()
+ for group_module in self.modules:
+ for param in group_module.parameters():
+ param.data = self.cpu_param_dict[param]
+ for param in self.parameters:
+ param.data = self.cpu_param_dict[param]
+ for buffer in self.buffers:
+ buffer.data = self.cpu_param_dict[buffer]
+
+ else:
+ for group_module in self.modules:
+ group_module.to(self.offload_device, non_blocking=self.non_blocking)
+ for param in self.parameters:
+ param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+ for buffer in self.buffers:
+ buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+
+
+class GroupOffloadingHook(ModelHook):
+ r"""
+ A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for
+ computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader"
+ module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module
+ group is responsible for onloading the current module group.
+ """
+
+ _is_stateful = False
+
+ def __init__(
+ self,
+ group: ModuleGroup,
+ next_group: Optional[ModuleGroup] = None,
+ ) -> None:
+ self.group = group
+ self.next_group = next_group
+
+ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+ if self.group.offload_leader == module:
+ self.group.offload_()
+ return module
+
+ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
+ # If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
+ # method is the onload_leader of the group.
+ if self.group.onload_leader is None:
+ self.group.onload_leader = module
+
+ # If the current module is the onload_leader of the group, we onload the group if it is supposed
+ # to onload itself. In the case of using prefetching with streams, we onload the next group if
+ # it is not supposed to onload itself.
+ if self.group.onload_leader == module:
+ if self.group.onload_self:
+ self.group.onload_()
+ if self.next_group is not None and not self.next_group.onload_self:
+ self.next_group.onload_()
+
+ args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
+ kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+ return args, kwargs
+
+ def post_forward(self, module: torch.nn.Module, output):
+ if self.group.offload_leader == module:
+ self.group.offload_()
+ return output
+
+
+class LazyPrefetchGroupOffloadingHook(ModelHook):
+ r"""
+ A hook, used in conjuction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
+ This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
+ invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
+ prefetching groups in the correct order.
+ """
+
+ _is_stateful = False
+
+ def __init__(self):
+ self.execution_order: List[Tuple[str, torch.nn.Module]] = []
+ self._layer_execution_tracker_module_names = set()
+
+ def initialize_hook(self, module):
+ def make_execution_order_update_callback(current_name, current_submodule):
+ def callback():
+ logger.debug(f"Adding {current_name} to the execution order")
+ self.execution_order.append((current_name, current_submodule))
+
+ return callback
+
+ # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
+ # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
+ # layers are executed during the forward pass.
+ for name, submodule in module.named_modules():
+ if name == "" or not hasattr(submodule, "_diffusers_hook"):
+ continue
+
+ registry = HookRegistry.check_if_exists_or_initialize(submodule)
+ group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
+
+ if group_offloading_hook is not None:
+ # For the first forward pass, we have to load in a blocking manner
+ group_offloading_hook.group.non_blocking = False
+ layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
+ registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
+ self._layer_execution_tracker_module_names.add(name)
+
+ return module
+
+ def post_forward(self, module, output):
+ # At this point, for the current modules' submodules, we know the execution order of the layers. We can now
+ # remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each
+ # group offloading hook.
+ num_executed = len(self.execution_order)
+ execution_order_module_names = {name for name, _ in self.execution_order}
+
+ # It may be possible that some layers were not executed during the forward pass. This can happen if the layer
+ # is not used in the forward pass, or if the layer is not executed due to some other reason. In such cases, we
+ # may not be able to apply prefetching in the correct order, which can lead to device-mismatch related errors
+ # if the missing layers end up being executed in the future.
+ if execution_order_module_names != self._layer_execution_tracker_module_names:
+ unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
+ logger.warning(
+ "It seems like some layers were not executed during the forward pass. This may lead to problems when "
+ "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
+ "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
+ f"{unexecuted_layers=}"
+ )
+
+ # Remove the layer execution tracker hooks from the submodules
+ base_module_registry = module._diffusers_hook
+ registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
+ group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
+
+ for i in range(num_executed):
+ registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
+
+ # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
+ base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
+
+ # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
+ # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
+ # see the benefits of prefetching.
+ for hook in group_offloading_hooks:
+ hook.group.non_blocking = True
+
+ # Set required attributes for prefetching
+ if num_executed > 0:
+ base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
+ base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
+ base_module_group_offloading_hook.next_group.onload_self = False
+
+ for i in range(num_executed - 1):
+ name1, _ = self.execution_order[i]
+ name2, _ = self.execution_order[i + 1]
+ logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
+ group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
+ group_offloading_hooks[i].next_group.onload_self = False
+
+ return output
+
+
+class LayerExecutionTrackerHook(ModelHook):
+ r"""
+ A hook that tracks the order in which the layers are executed during the forward pass by calling back to the
+ LazyPrefetchGroupOffloadingHook to update the execution order.
+ """
+
+ _is_stateful = False
+
+ def __init__(self, execution_order_update_callback):
+ self.execution_order_update_callback = execution_order_update_callback
+
+ def pre_forward(self, module, *args, **kwargs):
+ self.execution_order_update_callback()
+ return args, kwargs
+
+
+def apply_group_offloading(
+ module: torch.nn.Module,
+ onload_device: torch.device,
+ offload_device: torch.device = torch.device("cpu"),
+ offload_type: str = "block_level",
+ num_blocks_per_group: Optional[int] = None,
+ non_blocking: bool = False,
+ use_stream: bool = False,
+ low_cpu_mem_usage: bool = False,
+) -> None:
+ r"""
+ Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
+ where it is beneficial, we need to first provide some context on how other supported offloading methods work.
+
+ Typically, offloading is done at two levels:
+ - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
+ works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device
+ when needed for computation. This method is more memory-efficient than keeping all components on the accelerator,
+ but the memory requirements are still quite high. For this method to work, one needs memory equivalent to size of
+ the model in runtime dtype + size of largest intermediate activation tensors to be able to complete the forward
+ pass.
+ - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. It
+ works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
+ onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator
+ memory, but can be slower due to the excessive number of device synchronizations.
+
+ Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers,
+ (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
+ offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations is
+ reduced.
+
+ Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to
+ overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This
+ is enabled using layer prefetching with streams, i.e., the layer that is to be executed next starts onloading to
+ the accelerator device while the current layer is being executed - this increases the memory requirements slightly.
+ Note that this implementation also supports leaf-level offloading but can be made much faster when using streams.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module to which group offloading is applied.
+ onload_device (`torch.device`):
+ The device to which the group of modules are onloaded.
+ offload_device (`torch.device`, defaults to `torch.device("cpu")`):
+ The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
+ offload_type (`str`, defaults to "block_level"):
+ The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
+ "block_level".
+ num_blocks_per_group (`int`, *optional*):
+ The number of blocks per group when using offload_type="block_level". This is required when using
+ offload_type="block_level".
+ non_blocking (`bool`, defaults to `False`):
+ If True, offloading and onloading is done with non-blocking data transfer.
+ use_stream (`bool`, defaults to `False`):
+ If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
+ overlapping computation and data transfer.
+ low_cpu_mem_usage (`bool`, defaults to `False`):
+ If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
+ option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
+ the CPU memory is a bottleneck but may counteract the benefits of using streams.
+
+ Example:
+ ```python
+ >>> from diffusers import CogVideoXTransformer3DModel
+ >>> from diffusers.hooks import apply_group_offloading
+
+ >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+ ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+ ... )
+
+ >>> apply_group_offloading(
+ ... transformer,
+ ... onload_device=torch.device("cuda"),
+ ... offload_device=torch.device("cpu"),
+ ... offload_type="block_level",
+ ... num_blocks_per_group=2,
+ ... use_stream=True,
+ ... )
+ ```
+ """
+
+ stream = None
+ if use_stream:
+ if torch.cuda.is_available():
+ stream = torch.cuda.Stream()
+ else:
+ raise ValueError("Using streams for data transfer requires a CUDA device.")
+
+ _raise_error_if_accelerate_model_or_sequential_hook_present(module)
+
+ if offload_type == "block_level":
+ if num_blocks_per_group is None:
+ raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
+
+ _apply_group_offloading_block_level(
+ module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+ )
+ elif offload_type == "leaf_level":
+ _apply_group_offloading_leaf_level(
+ module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+ )
+ else:
+ raise ValueError(f"Unsupported offload_type: {offload_type}")
+
+
+def _apply_group_offloading_block_level(
+ module: torch.nn.Module,
+ num_blocks_per_group: int,
+ offload_device: torch.device,
+ onload_device: torch.device,
+ non_blocking: bool,
+ stream: Optional[torch.cuda.Stream] = None,
+ low_cpu_mem_usage: bool = False,
+) -> None:
+ r"""
+ This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
+ the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module to which group offloading is applied.
+ offload_device (`torch.device`):
+ The device to which the group of modules are offloaded. This should typically be the CPU.
+ onload_device (`torch.device`):
+ The device to which the group of modules are onloaded.
+ non_blocking (`bool`):
+ If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
+ and data transfer.
+ stream (`torch.cuda.Stream`, *optional*):
+ If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
+ for overlapping computation and data transfer.
+ """
+
+ # Create module groups for ModuleList and Sequential blocks
+ modules_with_group_offloading = set()
+ unmatched_modules = []
+ matched_module_groups = []
+ for name, submodule in module.named_children():
+ if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+ unmatched_modules.append((name, submodule))
+ modules_with_group_offloading.add(name)
+ continue
+
+ for i in range(0, len(submodule), num_blocks_per_group):
+ current_modules = submodule[i : i + num_blocks_per_group]
+ group = ModuleGroup(
+ modules=current_modules,
+ offload_device=offload_device,
+ onload_device=onload_device,
+ offload_leader=current_modules[-1],
+ onload_leader=current_modules[0],
+ non_blocking=non_blocking,
+ stream=stream,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ onload_self=stream is None,
+ )
+ matched_module_groups.append(group)
+ for j in range(i, i + len(current_modules)):
+ modules_with_group_offloading.add(f"{name}.{j}")
+
+ # Apply group offloading hooks to the module groups
+ for i, group in enumerate(matched_module_groups):
+ next_group = (
+ matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
+ )
+
+ for group_module in group.modules:
+ _apply_group_offloading_hook(group_module, group, next_group)
+
+ # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
+ # when the forward pass of this module is called. This is because the top-level module is not
+ # part of any group (as doing so would lead to no VRAM savings).
+ parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+ buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+ parameters = [param for _, param in parameters]
+ buffers = [buffer for _, buffer in buffers]
+
+ # Create a group for the unmatched submodules of the top-level module so that they are on the correct
+ # device when the forward pass is called.
+ unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
+ unmatched_group = ModuleGroup(
+ modules=unmatched_modules,
+ offload_device=offload_device,
+ onload_device=onload_device,
+ offload_leader=module,
+ onload_leader=module,
+ parameters=parameters,
+ buffers=buffers,
+ non_blocking=False,
+ stream=None,
+ onload_self=True,
+ )
+ next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
+ _apply_group_offloading_hook(module, unmatched_group, next_group)
+
+
+def _apply_group_offloading_leaf_level(
+ module: torch.nn.Module,
+ offload_device: torch.device,
+ onload_device: torch.device,
+ non_blocking: bool,
+ stream: Optional[torch.cuda.Stream] = None,
+ low_cpu_mem_usage: bool = False,
+) -> None:
+ r"""
+ This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
+ requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
+ synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
+ reduce memory usage without any performance degradation.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module to which group offloading is applied.
+ offload_device (`torch.device`):
+ The device to which the group of modules are offloaded. This should typically be the CPU.
+ onload_device (`torch.device`):
+ The device to which the group of modules are onloaded.
+ non_blocking (`bool`):
+ If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
+ and data transfer.
+ stream (`torch.cuda.Stream`, *optional*):
+ If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
+ for overlapping computation and data transfer.
+ """
+
+ # Create module groups for leaf modules and apply group offloading hooks
+ modules_with_group_offloading = set()
+ for name, submodule in module.named_modules():
+ if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+ continue
+ group = ModuleGroup(
+ modules=[submodule],
+ offload_device=offload_device,
+ onload_device=onload_device,
+ offload_leader=submodule,
+ onload_leader=submodule,
+ non_blocking=non_blocking,
+ stream=stream,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ onload_self=True,
+ )
+ _apply_group_offloading_hook(submodule, group, None)
+ modules_with_group_offloading.add(name)
+
+ # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
+ # of the module is called
+ module_dict = dict(module.named_modules())
+ parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+ buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+
+ # Find closest module parent for each parameter and buffer, and attach group hooks
+ parent_to_parameters = {}
+ for name, param in parameters:
+ parent_name = _find_parent_module_in_module_dict(name, module_dict)
+ if parent_name in parent_to_parameters:
+ parent_to_parameters[parent_name].append(param)
+ else:
+ parent_to_parameters[parent_name] = [param]
+
+ parent_to_buffers = {}
+ for name, buffer in buffers:
+ parent_name = _find_parent_module_in_module_dict(name, module_dict)
+ if parent_name in parent_to_buffers:
+ parent_to_buffers[parent_name].append(buffer)
+ else:
+ parent_to_buffers[parent_name] = [buffer]
+
+ parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
+ for name in parent_names:
+ parameters = parent_to_parameters.get(name, [])
+ buffers = parent_to_buffers.get(name, [])
+ parent_module = module_dict[name]
+ assert getattr(parent_module, "_diffusers_hook", None) is None
+ group = ModuleGroup(
+ modules=[],
+ offload_device=offload_device,
+ onload_device=onload_device,
+ offload_leader=parent_module,
+ onload_leader=parent_module,
+ parameters=parameters,
+ buffers=buffers,
+ non_blocking=non_blocking,
+ stream=stream,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ onload_self=True,
+ )
+ _apply_group_offloading_hook(parent_module, group, None)
+
+ if stream is not None:
+ # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
+ # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
+ # execution order and apply prefetching in the correct order.
+ unmatched_group = ModuleGroup(
+ modules=[],
+ offload_device=offload_device,
+ onload_device=onload_device,
+ offload_leader=module,
+ onload_leader=module,
+ parameters=None,
+ buffers=None,
+ non_blocking=False,
+ stream=None,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ onload_self=True,
+ )
+ _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+
+
+def _apply_group_offloading_hook(
+ module: torch.nn.Module,
+ group: ModuleGroup,
+ next_group: Optional[ModuleGroup] = None,
+) -> None:
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+
+ # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
+ # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
+ if registry.get_hook(_GROUP_OFFLOADING) is None:
+ hook = GroupOffloadingHook(group, next_group)
+ registry.register_hook(hook, _GROUP_OFFLOADING)
+
+
+def _apply_lazy_group_offloading_hook(
+ module: torch.nn.Module,
+ group: ModuleGroup,
+ next_group: Optional[ModuleGroup] = None,
+) -> None:
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+
+ # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
+ # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
+ if registry.get_hook(_GROUP_OFFLOADING) is None:
+ hook = GroupOffloadingHook(group, next_group)
+ registry.register_hook(hook, _GROUP_OFFLOADING)
+
+ lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
+ registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
+
+
+def _gather_parameters_with_no_group_offloading_parent(
+ module: torch.nn.Module, modules_with_group_offloading: Set[str]
+) -> List[torch.nn.Parameter]:
+ parameters = []
+ for name, parameter in module.named_parameters():
+ has_parent_with_group_offloading = False
+ atoms = name.split(".")
+ while len(atoms) > 0:
+ parent_name = ".".join(atoms)
+ if parent_name in modules_with_group_offloading:
+ has_parent_with_group_offloading = True
+ break
+ atoms.pop()
+ if not has_parent_with_group_offloading:
+ parameters.append((name, parameter))
+ return parameters
+
+
+def _gather_buffers_with_no_group_offloading_parent(
+ module: torch.nn.Module, modules_with_group_offloading: Set[str]
+) -> List[torch.Tensor]:
+ buffers = []
+ for name, buffer in module.named_buffers():
+ has_parent_with_group_offloading = False
+ atoms = name.split(".")
+ while len(atoms) > 0:
+ parent_name = ".".join(atoms)
+ if parent_name in modules_with_group_offloading:
+ has_parent_with_group_offloading = True
+ break
+ atoms.pop()
+ if not has_parent_with_group_offloading:
+ buffers.append((name, buffer))
+ return buffers
+
+
+def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
+ atoms = name.split(".")
+ while len(atoms) > 0:
+ parent_name = ".".join(atoms)
+ if parent_name in module_dict:
+ return parent_name
+ atoms.pop()
+ return ""
+
+
+def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
+ if not is_accelerate_available():
+ return
+ for name, submodule in module.named_modules():
+ if not hasattr(submodule, "_hf_hook"):
+ continue
+ if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)):
+ raise ValueError(
+ f"Cannot apply group offloading to a module that is already applying an alternative "
+ f"offloading strategy from Accelerate. If you want to apply group offloading, please "
+ f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
+ )
+
+
+def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+ for submodule in module.modules():
+ if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
+ return True
+ return False
+
+
+def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
+ for submodule in module.modules():
+ if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
+ return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
+ raise ValueError("Group offloading is not enabled for the provided module.")
diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py
new file mode 100644
index 000000000000..3b2e4ed91c2f
--- /dev/null
+++ b/src/diffusers/hooks/hooks.py
@@ -0,0 +1,236 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+
+from ..utils.logging import get_logger
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+class ModelHook:
+ r"""
+ A hook that contains callbacks to be executed just before and after the forward method of a model.
+ """
+
+ _is_stateful = False
+
+ def __init__(self):
+ self.fn_ref: "HookFunctionReference" = None
+
+ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+ r"""
+ Hook that is executed when a model is initialized.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module attached to this hook.
+ """
+ return module
+
+ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+ r"""
+ Hook that is executed when a model is deinitalized.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module attached to this hook.
+ """
+ return module
+
+ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
+ r"""
+ Hook that is executed just before the forward method of the model.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module whose forward pass will be executed just after this event.
+ args (`Tuple[Any]`):
+ The positional arguments passed to the module.
+ kwargs (`Dict[Str, Any]`):
+ The keyword arguments passed to the module.
+ Returns:
+ `Tuple[Tuple[Any], Dict[Str, Any]]`:
+ A tuple with the treated `args` and `kwargs`.
+ """
+ return args, kwargs
+
+ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
+ r"""
+ Hook that is executed just after the forward method of the model.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module whose forward pass been executed just before this event.
+ output (`Any`):
+ The output of the module.
+ Returns:
+ `Any`: The processed `output`.
+ """
+ return output
+
+ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+ r"""
+ Hook that is executed when the hook is detached from a module.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module detached from this hook.
+ """
+ return module
+
+ def reset_state(self, module: torch.nn.Module):
+ if self._is_stateful:
+ raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
+ return module
+
+
+class HookFunctionReference:
+ def __init__(self) -> None:
+ """A container class that maintains mutable references to forward pass functions in a hook chain.
+
+ Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
+ entire forward pass structure.
+
+ Attributes:
+ pre_forward: A callable that processes inputs before the main forward pass.
+ post_forward: A callable that processes outputs after the main forward pass.
+ forward: The current forward function in the hook chain.
+ original_forward: The original forward function, stored when a hook provides a custom new_forward.
+
+ The class enables hook removal by allowing updates to the forward chain through reference modification rather
+ than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
+ be updated, preserving the execution order of the remaining hooks.
+ """
+ self.pre_forward = None
+ self.post_forward = None
+ self.forward = None
+ self.original_forward = None
+
+
+class HookRegistry:
+ def __init__(self, module_ref: torch.nn.Module) -> None:
+ super().__init__()
+
+ self.hooks: Dict[str, ModelHook] = {}
+
+ self._module_ref = module_ref
+ self._hook_order = []
+ self._fn_refs = []
+
+ def register_hook(self, hook: ModelHook, name: str) -> None:
+ if name in self.hooks.keys():
+ raise ValueError(
+ f"Hook with name {name} already exists in the registry. Please use a different name or "
+ f"first remove the existing hook and then add a new one."
+ )
+
+ self._module_ref = hook.initialize_hook(self._module_ref)
+
+ def create_new_forward(function_reference: HookFunctionReference):
+ def new_forward(module, *args, **kwargs):
+ args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
+ output = function_reference.forward(*args, **kwargs)
+ return function_reference.post_forward(module, output)
+
+ return new_forward
+
+ forward = self._module_ref.forward
+
+ fn_ref = HookFunctionReference()
+ fn_ref.pre_forward = hook.pre_forward
+ fn_ref.post_forward = hook.post_forward
+ fn_ref.forward = forward
+
+ if hasattr(hook, "new_forward"):
+ fn_ref.original_forward = forward
+ fn_ref.forward = functools.update_wrapper(
+ functools.partial(hook.new_forward, self._module_ref), hook.new_forward
+ )
+
+ rewritten_forward = create_new_forward(fn_ref)
+ self._module_ref.forward = functools.update_wrapper(
+ functools.partial(rewritten_forward, self._module_ref), rewritten_forward
+ )
+
+ hook.fn_ref = fn_ref
+ self.hooks[name] = hook
+ self._hook_order.append(name)
+ self._fn_refs.append(fn_ref)
+
+ def get_hook(self, name: str) -> Optional[ModelHook]:
+ return self.hooks.get(name, None)
+
+ def remove_hook(self, name: str, recurse: bool = True) -> None:
+ if name in self.hooks.keys():
+ num_hooks = len(self._hook_order)
+ hook = self.hooks[name]
+ index = self._hook_order.index(name)
+ fn_ref = self._fn_refs[index]
+
+ old_forward = fn_ref.forward
+ if fn_ref.original_forward is not None:
+ old_forward = fn_ref.original_forward
+
+ if index == num_hooks - 1:
+ self._module_ref.forward = old_forward
+ else:
+ self._fn_refs[index + 1].forward = old_forward
+
+ self._module_ref = hook.deinitalize_hook(self._module_ref)
+ del self.hooks[name]
+ self._hook_order.pop(index)
+ self._fn_refs.pop(index)
+
+ if recurse:
+ for module_name, module in self._module_ref.named_modules():
+ if module_name == "":
+ continue
+ if hasattr(module, "_diffusers_hook"):
+ module._diffusers_hook.remove_hook(name, recurse=False)
+
+ def reset_stateful_hooks(self, recurse: bool = True) -> None:
+ for hook_name in reversed(self._hook_order):
+ hook = self.hooks[hook_name]
+ if hook._is_stateful:
+ hook.reset_state(self._module_ref)
+
+ if recurse:
+ for module_name, module in self._module_ref.named_modules():
+ if module_name == "":
+ continue
+ if hasattr(module, "_diffusers_hook"):
+ module._diffusers_hook.reset_stateful_hooks(recurse=False)
+
+ @classmethod
+ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
+ if not hasattr(module, "_diffusers_hook"):
+ module._diffusers_hook = cls(module)
+ return module._diffusers_hook
+
+ def __repr__(self) -> str:
+ registry_repr = ""
+ for i, hook_name in enumerate(self._hook_order):
+ if self.hooks[hook_name].__class__.__repr__ is not object.__repr__:
+ hook_repr = self.hooks[hook_name].__repr__()
+ else:
+ hook_repr = self.hooks[hook_name].__class__.__name__
+ registry_repr += f" ({i}) {hook_name} - {hook_repr}"
+ if i < len(self._hook_order) - 1:
+ registry_repr += "\n"
+ return f"HookRegistry(\n{registry_repr}\n)"
diff --git a/src/diffusers/hooks/layerwise_casting.py b/src/diffusers/hooks/layerwise_casting.py
new file mode 100644
index 000000000000..6f2cfdc3485a
--- /dev/null
+++ b/src/diffusers/hooks/layerwise_casting.py
@@ -0,0 +1,245 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Optional, Tuple, Type, Union
+
+import torch
+
+from ..utils import get_logger, is_peft_available, is_peft_version
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+# fmt: off
+_LAYERWISE_CASTING_HOOK = "layerwise_casting"
+_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
+SUPPORTED_PYTORCH_LAYERS = (
+ torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
+ torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
+ torch.nn.Linear,
+)
+
+DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
+# fmt: on
+
+_SHOULD_DISABLE_PEFT_INPUT_AUTOCAST = is_peft_available() and is_peft_version(">", "0.14.0")
+if _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
+ from peft.helpers import disable_input_dtype_casting
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+
+class LayerwiseCastingHook(ModelHook):
+ r"""
+ A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
+ for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
+ footprint.
+ """
+
+ _is_stateful = False
+
+ def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
+ self.storage_dtype = storage_dtype
+ self.compute_dtype = compute_dtype
+ self.non_blocking = non_blocking
+
+ def initialize_hook(self, module: torch.nn.Module):
+ module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
+ return module
+
+ def deinitalize_hook(self, module: torch.nn.Module):
+ raise NotImplementedError(
+ "LayerwiseCastingHook does not support deinitalization. A model once enabled with layerwise casting will "
+ "have casted its weights to a lower precision dtype for storage. Casting this back to the original dtype "
+ "will lead to precision loss, which might have an impact on the model's generation quality. The model should "
+ "be re-initialized and loaded in the original dtype."
+ )
+
+ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
+ module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
+ return args, kwargs
+
+ def post_forward(self, module: torch.nn.Module, output):
+ module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
+ return output
+
+
+class PeftInputAutocastDisableHook(ModelHook):
+ r"""
+ A hook that disables the casting of inputs to the module weight dtype during the forward pass. By default, PEFT
+ casts the inputs to the weight dtype of the module, which can lead to precision loss.
+
+ The reasons for needing this are:
+ - If we don't add PEFT layers' weight names to `skip_modules_pattern` when applying layerwise casting, the
+ inputs will be casted to the, possibly lower precision, storage dtype. Reference:
+ https://github.com/huggingface/peft/blob/0facdebf6208139cbd8f3586875acb378813dd97/src/peft/tuners/lora/layer.py#L706
+ - We can, on our end, use something like accelerate's `send_to_device` but for dtypes. This way, we can ensure
+ that the inputs are casted to the computation dtype correctly always. However, there are two goals we are
+ hoping to achieve:
+ 1. Making forward implementations independent of device/dtype casting operations as much as possible.
+ 2. Peforming inference without losing information from casting to different precisions. With the current
+ PEFT implementation (as linked in the reference above), and assuming running layerwise casting inference
+ with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to
+ torch.float8_e4m3fn in the lora layer. We will then upcast back to torch.bfloat16 when we continue the
+ forward pass in PEFT linear forward or Diffusers layer forward, with a `send_to_dtype` operation from
+ LayerwiseCastingHook. This will be a lossy operation and result in poorer generation quality.
+ """
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ with disable_input_dtype_casting(module):
+ return self.fn_ref.original_forward(*args, **kwargs)
+
+
+def apply_layerwise_casting(
+ module: torch.nn.Module,
+ storage_dtype: torch.dtype,
+ compute_dtype: torch.dtype,
+ skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto",
+ skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
+ non_blocking: bool = False,
+) -> None:
+ r"""
+ Applies layerwise casting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
+ nn.Module using diffusers layers or pytorch primitives.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from diffusers import CogVideoXTransformer3DModel
+
+ >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+ ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+ ... )
+
+ >>> apply_layerwise_casting(
+ ... transformer,
+ ... storage_dtype=torch.float8_e4m3fn,
+ ... compute_dtype=torch.bfloat16,
+ ... skip_modules_pattern=["patch_embed", "norm", "proj_out"],
+ ... non_blocking=True,
+ ... )
+ ```
+
+ Args:
+ module (`torch.nn.Module`):
+ The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
+ precision dtype for storage.
+ storage_dtype (`torch.dtype`):
+ The dtype to cast the module to before/after the forward pass for storage.
+ compute_dtype (`torch.dtype`):
+ The dtype to cast the module to during the forward pass for computation.
+ skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`):
+ A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
+ to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
+ alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
+ instead of its internal submodules.
+ skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
+ A list of module classes to skip during the layerwise casting process.
+ non_blocking (`bool`, defaults to `False`):
+ If `True`, the weight casting operations are non-blocking.
+ """
+ if skip_modules_pattern == "auto":
+ skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
+
+ if skip_modules_classes is None and skip_modules_pattern is None:
+ apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
+ return
+
+ _apply_layerwise_casting(
+ module,
+ storage_dtype,
+ compute_dtype,
+ skip_modules_pattern,
+ skip_modules_classes,
+ non_blocking,
+ )
+ _disable_peft_input_autocast(module)
+
+
+def _apply_layerwise_casting(
+ module: torch.nn.Module,
+ storage_dtype: torch.dtype,
+ compute_dtype: torch.dtype,
+ skip_modules_pattern: Optional[Tuple[str, ...]] = None,
+ skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
+ non_blocking: bool = False,
+ _prefix: str = "",
+) -> None:
+ should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
+ skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
+ )
+ if should_skip:
+ logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
+ return
+
+ if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+ logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
+ apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
+ return
+
+ for name, submodule in module.named_children():
+ layer_name = f"{_prefix}.{name}" if _prefix else name
+ _apply_layerwise_casting(
+ submodule,
+ storage_dtype,
+ compute_dtype,
+ skip_modules_pattern,
+ skip_modules_classes,
+ non_blocking,
+ _prefix=layer_name,
+ )
+
+
+def apply_layerwise_casting_hook(
+ module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
+) -> None:
+ r"""
+ Applies a `LayerwiseCastingHook` to a given module.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module to attach the hook to.
+ storage_dtype (`torch.dtype`):
+ The dtype to cast the module to before the forward pass.
+ compute_dtype (`torch.dtype`):
+ The dtype to cast the module to during the forward pass.
+ non_blocking (`bool`):
+ If `True`, the weight casting operations are non-blocking.
+ """
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+ hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
+ registry.register_hook(hook, _LAYERWISE_CASTING_HOOK)
+
+
+def _is_layerwise_casting_active(module: torch.nn.Module) -> bool:
+ for submodule in module.modules():
+ if (
+ hasattr(submodule, "_diffusers_hook")
+ and submodule._diffusers_hook.get_hook(_LAYERWISE_CASTING_HOOK) is not None
+ ):
+ return True
+ return False
+
+
+def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
+ if not _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
+ return
+ for submodule in module.modules():
+ if isinstance(submodule, BaseTunerLayer) and _is_layerwise_casting_active(submodule):
+ registry = HookRegistry.check_if_exists_or_initialize(submodule)
+ hook = PeftInputAutocastDisableHook()
+ registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py
new file mode 100644
index 000000000000..5d50f4b816c1
--- /dev/null
+++ b/src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -0,0 +1,311 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Tuple, Union
+
+import torch
+
+from ..models.attention_processor import Attention, MochiAttention
+from ..utils import logging
+from .hooks import HookRegistry, ModelHook
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
+_ATTENTION_CLASSES = (Attention, MochiAttention)
+_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
+_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
+
+
+@dataclass
+class PyramidAttentionBroadcastConfig:
+ r"""
+ Configuration for Pyramid Attention Broadcast.
+
+ Args:
+ spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
+ The number of times a specific spatial attention broadcast is skipped before computing the attention states
+ to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
+ old attention states will be re-used) before computing the new attention states again.
+ temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
+ The number of times a specific temporal attention broadcast is skipped before computing the attention
+ states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
+ (i.e., old attention states will be re-used) before computing the new attention states again.
+ cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
+ The number of times a specific cross-attention broadcast is skipped before computing the attention states
+ to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
+ old attention states will be re-used) before computing the new attention states again.
+ spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
+ The range of timesteps to skip in the spatial attention layer. The attention computations will be
+ conditionally skipped if the current timestep is within the specified range.
+ temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
+ The range of timesteps to skip in the temporal attention layer. The attention computations will be
+ conditionally skipped if the current timestep is within the specified range.
+ cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
+ The range of timesteps to skip in the cross-attention layer. The attention computations will be
+ conditionally skipped if the current timestep is within the specified range.
+ spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+ The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
+ temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
+ The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
+ cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+ The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
+ """
+
+ spatial_attention_block_skip_range: Optional[int] = None
+ temporal_attention_block_skip_range: Optional[int] = None
+ cross_attention_block_skip_range: Optional[int] = None
+
+ spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+ temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+ cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+
+ spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
+ temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
+ cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+
+ current_timestep_callback: Callable[[], int] = None
+
+ # TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase
+ # so not added for now)
+
+ def __repr__(self) -> str:
+ return (
+ f"PyramidAttentionBroadcastConfig(\n"
+ f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
+ f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
+ f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
+ f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
+ f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
+ f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n"
+ f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
+ f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
+ f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n"
+ f" current_timestep_callback={self.current_timestep_callback}\n"
+ ")"
+ )
+
+
+class PyramidAttentionBroadcastState:
+ r"""
+ State for Pyramid Attention Broadcast.
+
+ Attributes:
+ iteration (`int`):
+ The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
+ called before starting a new inference forward pass for PAB to work correctly.
+ cache (`Any`):
+ The cached output from the previous forward pass. This is used to re-use the attention states when the
+ attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module.
+ """
+
+ def __init__(self) -> None:
+ self.iteration = 0
+ self.cache = None
+
+ def reset(self):
+ self.iteration = 0
+ self.cache = None
+
+ def __repr__(self):
+ cache_repr = ""
+ if self.cache is None:
+ cache_repr = "None"
+ else:
+ cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
+ return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"
+
+
+class PyramidAttentionBroadcastHook(ModelHook):
+ r"""A hook that applies Pyramid Attention Broadcast to a given module."""
+
+ _is_stateful = True
+
+ def __init__(
+ self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
+ ) -> None:
+ super().__init__()
+
+ self.timestep_skip_range = timestep_skip_range
+ self.block_skip_range = block_skip_range
+ self.current_timestep_callback = current_timestep_callback
+
+ def initialize_hook(self, module):
+ self.state = PyramidAttentionBroadcastState()
+ return module
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
+ is_within_timestep_range = (
+ self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
+ )
+ should_compute_attention = (
+ self.state.cache is None
+ or self.state.iteration == 0
+ or not is_within_timestep_range
+ or self.state.iteration % self.block_skip_range == 0
+ )
+
+ if should_compute_attention:
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ else:
+ output = self.state.cache
+
+ self.state.cache = output
+ self.state.iteration += 1
+ return output
+
+ def reset_state(self, module: torch.nn.Module) -> None:
+ self.state.reset()
+ return module
+
+
+def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
+ r"""
+ Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
+
+ PAB is an attention approximation method that leverages the similarity in attention states between timesteps to
+ reduce the computational cost of attention computation. The key takeaway from the paper is that the attention
+ similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and
+ spatial layers. This allows for the skipping of attention computation in the cross-attention layers more frequently
+ than in the temporal and spatial layers. Applying PAB will, therefore, speedup the inference process.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module to apply Pyramid Attention Broadcast to.
+ config (`Optional[PyramidAttentionBroadcastConfig]`, `optional`, defaults to `None`):
+ The configuration to use for Pyramid Attention Broadcast.
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
+ >>> from diffusers.utils import export_to_video
+
+ >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> config = PyramidAttentionBroadcastConfig(
+ ... spatial_attention_block_skip_range=2,
+ ... spatial_attention_timestep_skip_range=(100, 800),
+ ... current_timestep_callback=lambda: pipe.current_timestep,
+ ... )
+ >>> apply_pyramid_attention_broadcast(pipe.transformer, config)
+ ```
+ """
+ if config.current_timestep_callback is None:
+ raise ValueError(
+ "The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
+ )
+
+ if (
+ config.spatial_attention_block_skip_range is None
+ and config.temporal_attention_block_skip_range is None
+ and config.cross_attention_block_skip_range is None
+ ):
+ logger.warning(
+ "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
+ "or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
+ "To avoid this warning, please set one of the above parameters."
+ )
+ config.spatial_attention_block_skip_range = 2
+
+ for name, submodule in module.named_modules():
+ if not isinstance(submodule, _ATTENTION_CLASSES):
+ # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
+ # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
+ # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
+ continue
+ _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)
+
+
+def _apply_pyramid_attention_broadcast_on_attention_class(
+ name: str, module: Attention, config: PyramidAttentionBroadcastConfig
+) -> bool:
+ is_spatial_self_attention = (
+ any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
+ and config.spatial_attention_block_skip_range is not None
+ and not getattr(module, "is_cross_attention", False)
+ )
+ is_temporal_self_attention = (
+ any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
+ and config.temporal_attention_block_skip_range is not None
+ and not getattr(module, "is_cross_attention", False)
+ )
+ is_cross_attention = (
+ any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers)
+ and config.cross_attention_block_skip_range is not None
+ and getattr(module, "is_cross_attention", False)
+ )
+
+ block_skip_range, timestep_skip_range, block_type = None, None, None
+ if is_spatial_self_attention:
+ block_skip_range = config.spatial_attention_block_skip_range
+ timestep_skip_range = config.spatial_attention_timestep_skip_range
+ block_type = "spatial"
+ elif is_temporal_self_attention:
+ block_skip_range = config.temporal_attention_block_skip_range
+ timestep_skip_range = config.temporal_attention_timestep_skip_range
+ block_type = "temporal"
+ elif is_cross_attention:
+ block_skip_range = config.cross_attention_block_skip_range
+ timestep_skip_range = config.cross_attention_timestep_skip_range
+ block_type = "cross"
+
+ if block_skip_range is None or timestep_skip_range is None:
+ logger.info(
+ f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
+ f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
+ f"however, that this layer may still be valid for applying PAB. Please specify the correct "
+ f"block identifiers in the configuration."
+ )
+ return False
+
+ logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
+ _apply_pyramid_attention_broadcast_hook(
+ module, timestep_skip_range, block_skip_range, config.current_timestep_callback
+ )
+ return True
+
+
+def _apply_pyramid_attention_broadcast_hook(
+ module: Union[Attention, MochiAttention],
+ timestep_skip_range: Tuple[int, int],
+ block_skip_range: int,
+ current_timestep_callback: Callable[[], int],
+):
+ r"""
+ Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module to apply Pyramid Attention Broadcast to.
+ timestep_skip_range (`Tuple[int, int]`):
+ The range of timesteps to skip in the attention layer. The attention computations will be conditionally
+ skipped if the current timestep is within the specified range.
+ block_skip_range (`int`):
+ The number of times a specific attention broadcast is skipped before computing the attention states to
+ re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
+ attention states will be re-used) before computing the new attention states again.
+ current_timestep_callback (`Callable[[], int]`):
+ A callback function that returns the current inference timestep.
+ """
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+ hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
+ registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index 0fffe67b0bdb..d6913f045ad2 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -236,7 +236,7 @@ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, to
`np.ndarray` or `torch.Tensor`:
The denormalized image array.
"""
- return (images / 2 + 0.5).clamp(0, 1)
+ return (images * 0.5 + 0.5).clamp(0, 1)
@staticmethod
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
@@ -537,6 +537,26 @@ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
return image
+ def _denormalize_conditionally(
+ self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
+ ) -> torch.Tensor:
+ r"""
+ Denormalize a batch of images based on a condition list.
+
+ Args:
+ images (`torch.Tensor`):
+ The input image tensor.
+ do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
+ A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
+ value of `do_normalize` in the `VaeImageProcessor` config.
+ """
+ if do_denormalize is None:
+ return self.denormalize(images) if self.config.do_normalize else images
+
+ return torch.stack(
+ [self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]
+ )
+
def get_default_height_width(
self,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
@@ -752,12 +772,7 @@ def postprocess(
if output_type == "latent":
return image
- if do_denormalize is None:
- do_denormalize = [self.config.do_normalize] * image.shape[0]
-
- image = torch.stack(
- [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
- )
+ image = self._denormalize_conditionally(image, do_denormalize)
if output_type == "pt":
return image
@@ -795,13 +810,11 @@ def apply_overlay(
The final image with the overlay applied.
"""
- width, height = image.width, image.height
-
- init_image = self.resize(init_image, width=width, height=height)
- mask = self.resize(mask, width=width, height=height)
+ width, height = init_image.width, init_image.height
init_image_masked = PIL.Image.new("RGBa", (width, height))
init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
+
init_image_masked = init_image_masked.convert("RGBA")
if crop_coords is not None:
@@ -968,12 +981,7 @@ def postprocess(
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"
- if do_denormalize is None:
- do_denormalize = [self.config.do_normalize] * image.shape[0]
-
- image = torch.stack(
- [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
- )
+ image = self._denormalize_conditionally(image, do_denormalize)
image = self.pt_to_numpy(image)
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index bf7212216845..3ba1bfacf3dd 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -55,7 +55,8 @@ def text_encoder_attn_modules(text_encoder):
if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
-
+ _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
+ _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
@@ -65,12 +66,23 @@ def text_encoder_attn_modules(text_encoder):
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
+ "LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
"CogVideoXLoraLoaderMixin",
+ "CogView4LoraLoaderMixin",
+ "Mochi1LoraLoaderMixin",
+ "HunyuanVideoLoraLoaderMixin",
+ "SanaLoraLoaderMixin",
+ "Lumina2LoraLoaderMixin",
+ "WanLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
- _import_structure["ip_adapter"] = ["IPAdapterMixin"]
+ _import_structure["ip_adapter"] = [
+ "IPAdapterMixin",
+ "FluxIPAdapterMixin",
+ "SD3IPAdapterMixin",
+ ]
_import_structure["peft"] = ["PeftAdapterMixin"]
@@ -78,19 +90,32 @@ def text_encoder_attn_modules(text_encoder):
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
+ from .transformer_flux import FluxTransformer2DLoadersMixin
+ from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
- from .ip_adapter import IPAdapterMixin
+ from .ip_adapter import (
+ FluxIPAdapterMixin,
+ IPAdapterMixin,
+ SD3IPAdapterMixin,
+ )
from .lora_pipeline import (
AmusedLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
+ CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,
+ HunyuanVideoLoraLoaderMixin,
LoraLoaderMixin,
+ LTXVideoLoraLoaderMixin,
+ Lumina2LoraLoaderMixin,
+ Mochi1LoraLoaderMixin,
+ SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
+ WanLoraLoaderMixin,
)
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index 1006dab9e4b9..21a1a70ff79b 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -23,7 +23,9 @@
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
+ _get_detailed_type,
_get_model_file,
+ _is_valid_type,
is_accelerate_available,
is_torch_version,
is_transformers_available,
@@ -33,17 +35,20 @@
if is_transformers_available():
- from transformers import (
- CLIPImageProcessor,
- CLIPVisionModelWithProjection,
- )
-
- from ..models.attention_processor import (
- AttnProcessor,
- AttnProcessor2_0,
- IPAdapterAttnProcessor,
- IPAdapterAttnProcessor2_0,
- )
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
+
+from ..models.attention_processor import (
+ AttnProcessor,
+ AttnProcessor2_0,
+ FluxAttnProcessor2_0,
+ FluxIPAdapterJointAttnProcessor2_0,
+ IPAdapterAttnProcessor,
+ IPAdapterAttnProcessor2_0,
+ IPAdapterXFormersAttnProcessor,
+ JointAttnProcessor2_0,
+ SD3IPAdapterJointAttnProcessor2_0,
+)
+
logger = logging.get_logger(__name__)
@@ -76,7 +81,7 @@ def load_ip_adapter(
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
- `weight_name`.
+ `subfolder`.
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
@@ -189,7 +194,7 @@ def load_ip_adapter(
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
- if keys != ["image_proj", "ip_adapter"]:
+ if "image_proj" not in keys and "ip_adapter" not in keys:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
state_dicts.append(state_dict)
@@ -210,7 +215,8 @@ def load_ip_adapter(
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
- ).to(self.device, dtype=self.dtype)
+ torch_dtype=self.dtype,
+ ).to(self.device)
self.register_modules(image_encoder=image_encoder)
else:
raise ValueError(
@@ -284,7 +290,9 @@ def set_ip_adapter_scale(self, scale):
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
for attn_name, attn_processor in unet.attn_processors.items():
- if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+ if isinstance(
+ attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+ ):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
@@ -342,7 +350,531 @@ def unload_ip_adapter(self):
)
attn_procs[name] = (
attn_processor_class
- if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
+ if isinstance(
+ value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+ )
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
+
+
+class FluxIPAdapterMixin:
+ """Mixin for handling Flux IP Adapters."""
+
+ @validate_hf_hub_args
+ def load_ip_adapter(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+ weight_name: Union[str, List[str]],
+ subfolder: Optional[Union[str, List[str]]] = "",
+ image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
+ image_encoder_subfolder: Optional[str] = "",
+ image_encoder_dtype: torch.dtype = torch.float16,
+ **kwargs,
+ ):
+ """
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+ subfolder (`str` or `List[str]`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+ list is passed, it should have the same length as `weight_name`.
+ weight_name (`str` or `List[str]`):
+ The name of the weight file to load. If a list is passed, it should have the same length as
+ `weight_name`.
+ image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
+ Can be either:
+
+ - A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
+ hosted on the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ """
+
+ # handle the list inputs for multiple IP Adapters
+ if not isinstance(weight_name, list):
+ weight_name = [weight_name]
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+ if len(pretrained_model_name_or_path_or_dict) == 1:
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+ if not isinstance(subfolder, list):
+ subfolder = [subfolder]
+ if len(subfolder) == 1:
+ subfolder = subfolder * len(weight_name)
+
+ if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+ raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+ if len(weight_name) != len(subfolder):
+ raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
+ # Load the main state dict first.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+ state_dicts = []
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
+ ):
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
+ ip_adapter_keys = ["double_blocks.", "ip_adapter."]
+ for key in f.keys():
+ if any(key.startswith(prefix) for prefix in image_proj_keys):
+ diffusers_name = ".".join(key.split(".")[1:])
+ state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
+ elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
+ diffusers_name = (
+ ".".join(key.split(".")[1:])
+ .replace("ip_adapter_double_stream_k_proj", "to_k_ip")
+ .replace("ip_adapter_double_stream_v_proj", "to_v_ip")
+ .replace("processor.", "")
+ )
+ state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
+ else:
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ keys = list(state_dict.keys())
+ if keys != ["image_proj", "ip_adapter"]:
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
+
+ state_dicts.append(state_dict)
+
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+ if image_encoder_pretrained_model_name_or_path is not None:
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
+ image_encoder = (
+ CLIPVisionModelWithProjection.from_pretrained(
+ image_encoder_pretrained_model_name_or_path,
+ subfolder=image_encoder_subfolder,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ dtype=image_encoder_dtype,
+ )
+ .to(self.device)
+ .eval()
+ )
+ self.register_modules(image_encoder=image_encoder)
+ else:
+ raise ValueError(
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+ )
+ else:
+ logger.warning(
+ "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
+ )
+
+ # create feature extractor if it has not been registered to the pipeline yet
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+ # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
+ default_clip_size = 224
+ clip_image_size = (
+ self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
+ )
+ feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
+ self.register_modules(feature_extractor=feature_extractor)
+
+ # load ip-adapter into transformer
+ self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+
+ def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
+ """
+ Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+ granular control over each IP-Adapter behavior. A config can be a float or a list.
+
+ `float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
+ length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
+ number of IP adapters and each must match the number of blocks.
+
+ Example:
+
+ ```py
+ # To use original IP-Adapter
+ scale = 1.0
+ pipeline.set_ip_adapter_scale(scale)
+
+
+ def LinearStrengthModel(start, finish, size):
+ return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
+
+
+ ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
+ pipeline.set_ip_adapter_scale(ip_strengths)
+ ```
+ """
+
+ scale_type = Union[int, float]
+ num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
+ num_layers = self.transformer.config.num_layers
+
+ # Single value for all layers of all IP-Adapters
+ if isinstance(scale, scale_type):
+ scale = [scale for _ in range(num_ip_adapters)]
+ # List of per-layer scales for a single IP-Adapter
+ elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
+ scale = [scale]
+ # Invalid scale type
+ elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
+ raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
+
+ if len(scale) != num_ip_adapters:
+ raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
+
+ if any(len(s) != num_layers for s in scale if isinstance(s, list)):
+ invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
+ raise ValueError(
+ f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
+ )
+
+ # Scalars are transformed to lists with length num_layers
+ scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
+
+ # Set scales. zip over scale_configs prevents going into single transformer layers
+ for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
+ attn_processor.scale = scale
+
+ def unload_ip_adapter(self):
+ """
+ Unloads the IP Adapter weights
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.unload_ip_adapter()
+ >>> ...
+ ```
+ """
+ # remove CLIP image encoder
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+ self.image_encoder = None
+ self.register_to_config(image_encoder=[None, None])
+
+ # remove feature extractor only when safety_checker is None as safety_checker uses
+ # the feature_extractor later
+ if not hasattr(self, "safety_checker"):
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+ self.feature_extractor = None
+ self.register_to_config(feature_extractor=[None, None])
+
+ # remove hidden encoder
+ self.transformer.encoder_hid_proj = None
+ self.transformer.config.encoder_hid_dim_type = None
+
+ # restore original Transformer attention processors layers
+ attn_procs = {}
+ for name, value in self.transformer.attn_processors.items():
+ attn_processor_class = FluxAttnProcessor2_0()
+ attn_procs[name] = (
+ attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+ )
+ self.transformer.set_attn_processor(attn_procs)
+
+
+class SD3IPAdapterMixin:
+ """Mixin for handling StableDiffusion 3 IP Adapters."""
+
+ @property
+ def is_ip_adapter_active(self) -> bool:
+ """Checks if IP-Adapter is loaded and scale > 0.
+
+ IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
+ the image context is irrelevant.
+
+ Returns:
+ `bool`: True when IP-Adapter is loaded and any layer has scale > 0.
+ """
+ scales = [
+ attn_proc.scale
+ for attn_proc in self.transformer.attn_processors.values()
+ if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
+ ]
+
+ return len(scales) > 0 and any(scale > 0 for scale in scales)
+
+ @validate_hf_hub_args
+ def load_ip_adapter(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ weight_name: str = "ip-adapter.safetensors",
+ subfolder: Optional[str] = None,
+ image_encoder_folder: Optional[str] = "image_encoder",
+ **kwargs,
+ ) -> None:
+ """
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+ weight_name (`str`, defaults to "ip-adapter.safetensors"):
+ The name of the weight file to load. If a list is passed, it should have the same length as
+ `subfolder`.
+ subfolder (`str`, *optional*):
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+ list is passed, it should have the same length as `weight_name`.
+ image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
+ The subfolder location of the image encoder within a larger model repository on the Hub or locally.
+ Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
+ `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
+ `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
+ `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
+ `image_encoder_folder="different_subfolder/image_encoder"`.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ """
+ # Load the main state dict first
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ keys = list(state_dict.keys())
+ if "image_proj" not in keys and "ip_adapter" not in keys:
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
+
+ # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+ if image_encoder_folder is not None:
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+ if image_encoder_folder.count("/") == 0:
+ image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
+ else:
+ image_encoder_subfolder = Path(image_encoder_folder).as_posix()
+
+ # Commons args for loading image encoder and image processor
+ kwargs = {
+ "low_cpu_mem_usage": low_cpu_mem_usage,
+ "cache_dir": cache_dir,
+ "local_files_only": local_files_only,
+ }
+
+ self.register_modules(
+ feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
+ image_encoder=SiglipVisionModel.from_pretrained(
+ image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
+ ).to(self.device),
+ )
+ else:
+ raise ValueError(
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+ )
+ else:
+ logger.warning(
+ "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
+ )
+
+ # Load IP-Adapter into transformer
+ self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
+
+ def set_ip_adapter_scale(self, scale: float) -> None:
+ """
+ Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
+ conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
+ the model to produce more diverse images, but they may not be as aligned with the image prompt.
+
+ Example:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.set_ip_adapter_scale(0.6)
+ >>> ...
+ ```
+
+ Args:
+ scale (float):
+ IP-Adapter scale to be set.
+
+ """
+ for attn_processor in self.transformer.attn_processors.values():
+ if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
+ attn_processor.scale = scale
+
+ def unload_ip_adapter(self) -> None:
+ """
+ Unloads the IP Adapter weights.
+
+ Example:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.unload_ip_adapter()
+ >>> ...
+ ```
+ """
+ # Remove image encoder
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
+ self.image_encoder = None
+ self.register_to_config(image_encoder=None)
+
+ # Remove feature extractor
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+ self.feature_extractor = None
+ self.register_to_config(feature_extractor=None)
+
+ # Remove image projection
+ self.transformer.image_proj = None
+
+ # Restore original attention processors layers
+ attn_procs = {
+ name: (
+ JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
+ )
+ for name, value in self.transformer.attn_processors.items()
+ }
+ self.transformer.set_attn_processor(attn_procs)
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index e124b6eeacf3..17ed8c5444fc 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -28,13 +28,20 @@
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
+ convert_state_dict_to_diffusers,
+ convert_state_dict_to_peft,
delete_adapter_layers,
deprecate,
+ get_adapter_name,
+ get_peft_kwargs,
is_accelerate_available,
is_peft_available,
+ is_peft_version,
is_transformers_available,
+ is_transformers_version,
logging,
recurse_remove_peft_layers,
+ scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
@@ -43,6 +50,8 @@
if is_transformers_available():
from transformers import PreTrainedModel
+ from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
+
if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -51,6 +60,9 @@
logger = logging.get_logger(__name__)
+LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
"""
@@ -181,6 +193,269 @@ def _remove_text_encoder_monkey_patch(text_encoder):
text_encoder._hf_peft_config_loaded = None
+def _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ weight_name,
+ use_safetensors,
+ local_files_only,
+ cache_dir,
+ force_download,
+ proxies,
+ token,
+ revision,
+ subfolder,
+ user_agent,
+ allow_pickle,
+):
+ model_file = None
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ # Here we're relaxing the loading check to enable more Inference API
+ # friendliness where sometimes, it's not at all possible to automatically
+ # determine `weight_name`.
+ if weight_name is None:
+ weight_name = _best_guess_weight_name(
+ pretrained_model_name_or_path_or_dict,
+ file_extension=".safetensors",
+ local_files_only=local_files_only,
+ )
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except (IOError, safetensors.SafetensorError) as e:
+ if not allow_pickle:
+ raise e
+ # try loading non-safetensors weights
+ model_file = None
+ pass
+
+ if model_file is None:
+ if weight_name is None:
+ weight_name = _best_guess_weight_name(
+ pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
+ )
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ return state_dict
+
+
+def _best_guess_weight_name(
+ pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
+):
+ if local_files_only or HF_HUB_OFFLINE:
+ raise ValueError("When using the offline mode, you must specify a `weight_name`.")
+
+ targeted_files = []
+
+ if os.path.isfile(pretrained_model_name_or_path_or_dict):
+ return
+ elif os.path.isdir(pretrained_model_name_or_path_or_dict):
+ targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
+ else:
+ files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
+ targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
+ if len(targeted_files) == 0:
+ return
+
+ # "scheduler" does not correspond to a LoRA checkpoint.
+ # "optimizer" does not correspond to a LoRA checkpoint
+ # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
+ unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
+ targeted_files = list(
+ filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
+ )
+
+ if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
+ elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
+
+ if len(targeted_files) > 1:
+ raise ValueError(
+ f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+ )
+ weight_name = targeted_files[0]
+ return weight_name
+
+
+def _load_lora_into_text_encoder(
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ text_encoder_name="text_encoder",
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+):
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ peft_kwargs = {}
+ if low_cpu_mem_usage:
+ if not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+ if not is_transformers_version(">", "4.45.2"):
+ # Note from sayakpaul: It's not in `transformers` stable yet.
+ # https://github.com/huggingface/transformers/pull/33725/
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+ )
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+ from peft import LoraConfig
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
+ # their prefixes.
+ prefix = text_encoder_name if prefix is None else prefix
+
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ if prefix is not None:
+ state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+
+ if len(state_dict) > 0:
+ logger.info(f"Loading {prefix}.")
+ rank = {}
+ state_dict = convert_state_dict_to_diffusers(state_dict)
+
+ # convert state dict
+ state_dict = convert_state_dict_to_peft(state_dict)
+
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in state_dict:
+ continue
+ rank[rank_key] = state_dict[rank_key].shape[1]
+
+ for name, _ in text_encoder_mlp_modules(text_encoder):
+ for module in ("fc1", "fc2"):
+ rank_key = f"{name}.{module}.lora_B.weight"
+ if rank_key not in state_dict:
+ continue
+ rank[rank_key] = state_dict[rank_key].shape[1]
+
+ if network_alphas is not None:
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+ network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"]:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<", "0.9.0"):
+ lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(text_encoder)
+
+ is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+
+ # inject LoRA layers and load the state dict
+ # in transformers we automatically check whether the adapter name is already in use or not
+ text_encoder.load_adapter(
+ adapter_name=adapter_name,
+ adapter_state_dict=state_dict,
+ peft_config=lora_config,
+ **peft_kwargs,
+ )
+
+ # scale LoRA layers with `lora_scale`
+ scale_lora_layers(text_encoder, weight=lora_scale)
+
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ if prefix is not None and not state_dict:
+ logger.warning(
+ f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
+ "This is safe to ignore if LoRA state dict didn't originally have any "
+ f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
+ "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
+ "https://github.com/huggingface/diffusers/issues/new"
+ )
+
+
+def _func_optionally_disable_offloading(_pipeline):
+ is_model_cpu_offload = False
+ is_sequential_cpu_offload = False
+
+ if _pipeline is not None and _pipeline.hf_device_map is None:
+ for _, component in _pipeline.components.items():
+ if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+ if not is_model_cpu_offload:
+ is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+ if not is_sequential_cpu_offload:
+ is_sequential_cpu_offload = (
+ isinstance(component._hf_hook, AlignDevicesHook)
+ or hasattr(component._hf_hook, "hooks")
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+ )
+
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+ )
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+ return (is_model_cpu_offload, is_sequential_cpu_offload)
+
+
class LoraBaseMixin:
"""Utility class for handling LoRAs."""
@@ -211,147 +486,19 @@ def _optionally_disable_offloading(cls, _pipeline):
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
- is_model_cpu_offload = False
- is_sequential_cpu_offload = False
-
- if _pipeline is not None and _pipeline.hf_device_map is None:
- for _, component in _pipeline.components.items():
- if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
- if not is_model_cpu_offload:
- is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
- if not is_sequential_cpu_offload:
- is_sequential_cpu_offload = (
- isinstance(component._hf_hook, AlignDevicesHook)
- or hasattr(component._hf_hook, "hooks")
- and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
- )
-
- logger.info(
- "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
- )
- remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
- return (is_model_cpu_offload, is_sequential_cpu_offload)
+ return _func_optionally_disable_offloading(_pipeline=_pipeline)
@classmethod
- def _fetch_state_dict(
- cls,
- pretrained_model_name_or_path_or_dict,
- weight_name,
- use_safetensors,
- local_files_only,
- cache_dir,
- force_download,
- proxies,
- token,
- revision,
- subfolder,
- user_agent,
- allow_pickle,
- ):
- from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
-
- model_file = None
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- # Let's first try to load .safetensors weights
- if (use_safetensors and weight_name is None) or (
- weight_name is not None and weight_name.endswith(".safetensors")
- ):
- try:
- # Here we're relaxing the loading check to enable more Inference API
- # friendliness where sometimes, it's not at all possible to automatically
- # determine `weight_name`.
- if weight_name is None:
- weight_name = cls._best_guess_weight_name(
- pretrained_model_name_or_path_or_dict,
- file_extension=".safetensors",
- local_files_only=local_files_only,
- )
- model_file = _get_model_file(
- pretrained_model_name_or_path_or_dict,
- weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- )
- state_dict = safetensors.torch.load_file(model_file, device="cpu")
- except (IOError, safetensors.SafetensorError) as e:
- if not allow_pickle:
- raise e
- # try loading non-safetensors weights
- model_file = None
- pass
-
- if model_file is None:
- if weight_name is None:
- weight_name = cls._best_guess_weight_name(
- pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
- )
- model_file = _get_model_file(
- pretrained_model_name_or_path_or_dict,
- weights_name=weight_name or LORA_WEIGHT_NAME,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- )
- state_dict = load_state_dict(model_file)
- else:
- state_dict = pretrained_model_name_or_path_or_dict
-
- return state_dict
+ def _fetch_state_dict(cls, *args, **kwargs):
+ deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
+ deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
+ return _fetch_state_dict(*args, **kwargs)
@classmethod
- def _best_guess_weight_name(
- cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
- ):
- from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
-
- if local_files_only or HF_HUB_OFFLINE:
- raise ValueError("When using the offline mode, you must specify a `weight_name`.")
-
- targeted_files = []
-
- if os.path.isfile(pretrained_model_name_or_path_or_dict):
- return
- elif os.path.isdir(pretrained_model_name_or_path_or_dict):
- targeted_files = [
- f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
- ]
- else:
- files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
- targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
- if len(targeted_files) == 0:
- return
-
- # "scheduler" does not correspond to a LoRA checkpoint.
- # "optimizer" does not correspond to a LoRA checkpoint
- # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
- unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
- targeted_files = list(
- filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
- )
-
- if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
- targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
- elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
- targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
-
- if len(targeted_files) > 1:
- raise ValueError(
- f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
- )
- weight_name = targeted_files[0]
- return weight_name
+ def _best_guess_weight_name(cls, *args, **kwargs):
+ deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
+ deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
+ return _best_guess_weight_name(*args, **kwargs)
def unload_lora_weights(self):
"""
@@ -518,8 +665,20 @@ def set_adapters(
adapter_names: Union[List[str], str],
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
- adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+ if isinstance(adapter_weights, dict):
+ components_passed = set(adapter_weights.keys())
+ lora_components = set(self._lora_loadable_modules)
+
+ invalid_components = sorted(components_passed - lora_components)
+ if invalid_components:
+ logger.warning(
+ f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
+ f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
+ "to the invalid components will be removed and ignored."
+ )
+ adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
adapter_weights = copy.deepcopy(adapter_weights)
# Expand weights into a list, one entry per adapter
@@ -554,12 +713,6 @@ def set_adapters(
for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
component_adapter_weights = weights.pop(component, None)
-
- if component_adapter_weights is not None and not hasattr(self, component):
- logger.warning(
- f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
- )
-
if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
logger.warning(
(
@@ -725,8 +878,6 @@ def write_lora_layers(
save_function: Callable,
safe_serialization: bool,
):
- from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
-
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index d0ca40213b14..20fcb61f3b80 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -519,7 +519,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
- if not all(k.startswith("lora_te1") for k in remaining_keys):
+ if not all(k.startswith("lora_te") for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
@@ -558,6 +558,135 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
new_state_dict = {**ait_sd, **te_state_dict}
return new_state_dict
+ def _convert_mixture_state_dict_to_diffusers(state_dict):
+ new_state_dict = {}
+
+ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
+ down_key = f"{original_key}.lora_down.weight"
+ down_weight = state_dict.pop(down_key)
+ lora_rank = down_weight.shape[0]
+
+ up_weight_key = f"{original_key}.lora_up.weight"
+ up_weight = state_dict.pop(up_weight_key)
+
+ alpha_key = f"{original_key}.alpha"
+ alpha = state_dict.pop(alpha_key)
+
+ # scale weight by alpha and dim
+ scale = alpha / lora_rank
+ # calculate scale_down and scale_up
+ scale_down = scale
+ scale_up = 1.0
+ while scale_down * 2 < scale_up:
+ scale_down *= 2
+ scale_up /= 2
+ down_weight = down_weight * scale_down
+ up_weight = up_weight * scale_up
+
+ diffusers_down_key = f"{diffusers_key}.lora_A.weight"
+ new_state_dict[diffusers_down_key] = down_weight
+ new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
+
+ all_unique_keys = {
+ k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
+ for k in state_dict
+ if not k.startswith(("lora_unet_"))
+ }
+ assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"
+
+ has_te_keys = False
+ for k in all_unique_keys:
+ if k.startswith("lora_transformer_single_transformer_blocks_"):
+ i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
+ diffusers_key = f"single_transformer_blocks.{i}"
+ elif k.startswith("lora_transformer_transformer_blocks_"):
+ i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
+ diffusers_key = f"transformer_blocks.{i}"
+ elif k.startswith("lora_te1_"):
+ has_te_keys = True
+ continue
+ else:
+ raise NotImplementedError
+
+ if "attn_" in k:
+ if "_to_out_0" in k:
+ diffusers_key += ".attn.to_out.0"
+ elif "_to_add_out" in k:
+ diffusers_key += ".attn.to_add_out"
+ elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
+ remaining = k.split("attn_")[-1]
+ diffusers_key += f".attn.{remaining}"
+ elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
+ remaining = k.split("attn_")[-1]
+ diffusers_key += f".attn.{remaining}"
+
+ _convert(k, diffusers_key, state_dict, new_state_dict)
+
+ if has_te_keys:
+ layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
+ attn_mapping = {
+ "q_proj": ".self_attn.q_proj",
+ "k_proj": ".self_attn.k_proj",
+ "v_proj": ".self_attn.v_proj",
+ "out_proj": ".self_attn.out_proj",
+ }
+ mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
+ for k in all_unique_keys:
+ if not k.startswith("lora_te1_"):
+ continue
+
+ match = layer_pattern.search(k)
+ if not match:
+ continue
+ i = int(match.group(1))
+ diffusers_key = f"text_model.encoder.layers.{i}"
+
+ if "attn" in k:
+ for key_fragment, suffix in attn_mapping.items():
+ if key_fragment in k:
+ diffusers_key += suffix
+ break
+ elif "mlp" in k:
+ for key_fragment, suffix in mlp_mapping.items():
+ if key_fragment in k:
+ diffusers_key += suffix
+ break
+
+ _convert(k, diffusers_key, state_dict, new_state_dict)
+
+ remaining_all_unet = False
+ if state_dict:
+ remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
+ if remaining_all_unet:
+ keys = list(state_dict.keys())
+ for k in keys:
+ state_dict.pop(k)
+
+ if len(state_dict) > 0:
+ raise ValueError(
+ f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
+ )
+
+ transformer_state_dict = {
+ f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
+ }
+ te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
+ return {**transformer_state_dict, **te_state_dict}
+
+ # This is weird.
+ # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
+ # has both `peft` and non-peft state dict.
+ has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
+ if has_peft_state_dict:
+ state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
+ return state_dict
+ # Another weird one.
+ has_mixture = any(
+ k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
+ )
+ if has_mixture:
+ return _convert_mixture_state_dict_to_diffusers(state_dict)
+
return _convert_sd_scripts_to_ai_toolkit(state_dict)
@@ -636,10 +765,19 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
new_key = f"transformer.single_transformer_blocks.{block_num}"
- if "proj_lora1" in old_key or "proj_lora2" in old_key:
+ if "proj_lora" in old_key:
new_key += ".proj_out"
- elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
- new_key += ".norm.linear"
+ elif "qkv_lora" in old_key and "up" not in old_key:
+ handle_qkv(
+ old_state_dict,
+ new_state_dict,
+ old_key,
+ [
+ f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
+ f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
+ f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
+ ],
+ )
if "down" in old_key:
new_key += ".lora_A.weight"
@@ -658,3 +796,608 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
return new_state_dict
+
+
+def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
+ converted_state_dict = {}
+ original_state_dict_keys = list(original_state_dict.keys())
+ num_layers = 19
+ num_single_layers = 38
+ inner_dim = 3072
+ mlp_ratio = 4.0
+
+ def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+ for lora_key in ["lora_A", "lora_B"]:
+ ## time_text_embed.timestep_embedder <- time_in
+ converted_state_dict[
+ f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
+ ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+ if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[
+ f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
+ ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+
+ converted_state_dict[
+ f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
+ ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+ if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[
+ f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
+ ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+
+ ## time_text_embed.text_embedder <- vector_in
+ converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
+ f"vector_in.in_layer.{lora_key}.weight"
+ )
+ if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop(
+ f"vector_in.in_layer.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop(
+ f"vector_in.out_layer.{lora_key}.weight"
+ )
+ if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop(
+ f"vector_in.out_layer.{lora_key}.bias"
+ )
+
+ # guidance
+ has_guidance = any("guidance" in k for k in original_state_dict)
+ if has_guidance:
+ converted_state_dict[
+ f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
+ ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+ if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[
+ f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
+ ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+
+ converted_state_dict[
+ f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
+ ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+ if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[
+ f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
+ ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+
+ # context_embedder
+ converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
+ f"txt_in.{lora_key}.weight"
+ )
+ if f"txt_in.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop(
+ f"txt_in.{lora_key}.bias"
+ )
+
+ # x_embedder
+ converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight")
+ if f"img_in.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias")
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+
+ for lora_key in ["lora_A", "lora_B"]:
+ # norms
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+ )
+ if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+ )
+ if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias"
+ )
+
+ # Q, K, V
+ if lora_key == "lora_A":
+ sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+ context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight")
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ else:
+ sample_q, sample_k, sample_v = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+ context_q, context_k, context_v = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+ if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+ if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+ # ff img_mlp
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+ )
+ if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+ )
+ if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+ )
+ if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+ )
+ if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+ )
+
+ # output projections.
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+ )
+ if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+ )
+ if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+ )
+
+ # qk_norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+ )
+
+ # single transfomer blocks
+ for i in range(num_single_layers):
+ block_prefix = f"single_transformer_blocks.{i}."
+
+ for lora_key in ["lora_A", "lora_B"]:
+ # norm.linear <- single_blocks.0.modulation.lin
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.lin.{lora_key}.weight"
+ )
+ if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.lin.{lora_key}.bias"
+ )
+
+ # Q, K, V, mlp
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+ if lora_key == "lora_A":
+ lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+ if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+ lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+ else:
+ q, k, v, mlp = torch.split(
+ original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+ if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
+ original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+ # output projections.
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.{lora_key}.weight"
+ )
+ if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.{lora_key}.bias"
+ )
+
+ # qk norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.norm.key_norm.scale"
+ )
+
+ for lora_key in ["lora_A", "lora_B"]:
+ converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"final_layer.linear.{lora_key}.weight"
+ )
+ if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"final_layer.linear.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
+ original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight")
+ )
+ if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift(
+ original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
+ )
+
+ if len(original_state_dict) > 0:
+ raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
+
+
+def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
+ converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
+
+ def remap_norm_scale_shift_(key, state_dict):
+ weight = state_dict.pop(key)
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+ def remap_txt_in_(key, state_dict):
+ def rename_key(key):
+ new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+ new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+ new_key = new_key.replace("txt_in", "context_embedder")
+ new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+ new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+ new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+ new_key = new_key.replace("mlp", "ff")
+ return new_key
+
+ if "self_attn_qkv" in key:
+ weight = state_dict.pop(key)
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+ else:
+ state_dict[rename_key(key)] = state_dict.pop(key)
+
+ def remap_img_attn_qkv_(key, state_dict):
+ weight = state_dict.pop(key)
+ if "lora_A" in key:
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
+ else:
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+ def remap_txt_attn_qkv_(key, state_dict):
+ weight = state_dict.pop(key)
+ if "lora_A" in key:
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
+ else:
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+ def remap_single_transformer_blocks_(key, state_dict):
+ hidden_size = 3072
+
+ if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
+ linear1_weight = state_dict.pop(key)
+ if "lora_A" in key:
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+ ".linear1.lora_A.weight"
+ )
+ state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
+ state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
+ state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
+ state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
+ else:
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+ q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+ ".linear1.lora_B.weight"
+ )
+ state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
+ state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
+ state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
+ state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
+
+ elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
+ linear1_bias = state_dict.pop(key)
+ if "lora_A" in key:
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+ ".linear1.lora_A.bias"
+ )
+ state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
+ state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
+ state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
+ state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
+ else:
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+ ".linear1.lora_B.bias"
+ )
+ state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
+ state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
+ state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
+ state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
+
+ else:
+ new_key = key.replace("single_blocks", "single_transformer_blocks")
+ new_key = new_key.replace("linear2", "proj_out")
+ new_key = new_key.replace("q_norm", "attn.norm_q")
+ new_key = new_key.replace("k_norm", "attn.norm_k")
+ state_dict[new_key] = state_dict.pop(key)
+
+ TRANSFORMER_KEYS_RENAME_DICT = {
+ "img_in": "x_embedder",
+ "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+ "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+ "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+ "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+ "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+ "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+ "double_blocks": "transformer_blocks",
+ "img_attn_q_norm": "attn.norm_q",
+ "img_attn_k_norm": "attn.norm_k",
+ "img_attn_proj": "attn.to_out.0",
+ "txt_attn_q_norm": "attn.norm_added_q",
+ "txt_attn_k_norm": "attn.norm_added_k",
+ "txt_attn_proj": "attn.to_add_out",
+ "img_mod.linear": "norm1.linear",
+ "img_norm1": "norm1.norm",
+ "img_norm2": "norm2",
+ "img_mlp": "ff",
+ "txt_mod.linear": "norm1_context.linear",
+ "txt_norm1": "norm1.norm",
+ "txt_norm2": "norm2_context",
+ "txt_mlp": "ff_context",
+ "self_attn_proj": "attn.to_out.0",
+ "modulation.linear": "norm.linear",
+ "pre_norm": "norm.norm",
+ "final_layer.norm_final": "norm_out.norm",
+ "final_layer.linear": "proj_out",
+ "fc1": "net.0.proj",
+ "fc2": "net.2",
+ "input_embedder": "proj_in",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "txt_in": remap_txt_in_,
+ "img_attn_qkv": remap_img_attn_qkv_,
+ "txt_attn_qkv": remap_txt_attn_qkv_,
+ "single_blocks": remap_single_transformer_blocks_,
+ "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+ }
+
+ # Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys
+ # and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make
+ # sure that both follow the same initial format by stripping off the "transformer." prefix.
+ for key in list(converted_state_dict.keys()):
+ if key.startswith("transformer."):
+ converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
+ if key.startswith("diffusion_model."):
+ converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
+
+ # Rename and remap the state dict keys
+ for key in list(converted_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ # Add back the "transformer." prefix
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
+
+
+def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
+ # Remove "diffusion_model." prefix from keys.
+ state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+ converted_state_dict = {}
+
+ def get_num_layers(keys, pattern):
+ layers = set()
+ for key in keys:
+ match = re.search(pattern, key)
+ if match:
+ layers.add(int(match.group(1)))
+ return len(layers)
+
+ def process_block(prefix, index, convert_norm):
+ # Process attention qkv: pop lora_A and lora_B weights.
+ lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
+ lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
+ for attn_key in ["to_q", "to_k", "to_v"]:
+ converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
+ for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
+ converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
+
+ # Process attention out weights.
+ converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
+ f"{prefix}.{index}.attention.out.lora_A.weight"
+ )
+ converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
+ f"{prefix}.{index}.attention.out.lora_B.weight"
+ )
+
+ # Process feed-forward weights for layers 1, 2, and 3.
+ for layer in range(1, 4):
+ converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
+ f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
+ )
+ converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
+ f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
+ )
+
+ if convert_norm:
+ converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
+ f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
+ )
+ converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
+ f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
+ )
+
+ noise_refiner_pattern = r"noise_refiner\.(\d+)\."
+ num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
+ for i in range(num_noise_refiner_layers):
+ process_block("noise_refiner", i, convert_norm=True)
+
+ context_refiner_pattern = r"context_refiner\.(\d+)\."
+ num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
+ for i in range(num_context_refiner_layers):
+ process_block("context_refiner", i, convert_norm=False)
+
+ core_transformer_pattern = r"layers\.(\d+)\."
+ num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
+ for i in range(num_core_transformer_layers):
+ process_block("layers", i, convert_norm=True)
+
+ if len(state_dict) > 0:
+ raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
+
+
+def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
+ converted_state_dict = {}
+ original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+
+ num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
+ is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+
+ for i in range(num_blocks):
+ # Self-attention
+ for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+ converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
+ f"blocks.{i}.self_attn.{o}.lora_A.weight"
+ )
+ converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
+ f"blocks.{i}.self_attn.{o}.lora_B.weight"
+ )
+
+ # Cross-attention
+ for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
+ f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+ )
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
+ f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+ )
+
+ if is_i2v_lora:
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
+ f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+ )
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
+ f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+ )
+
+ # FFN
+ for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
+ converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
+ f"blocks.{i}.{o}.lora_A.weight"
+ )
+ converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
+ f"blocks.{i}.{o}.lora_B.weight"
+ )
+
+ if len(original_state_dict) > 0:
+ raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 5e01ec567f9a..e522778deeed 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
from typing import Callable, Dict, List, Optional, Union
@@ -19,24 +20,29 @@
from ..utils import (
USE_PEFT_BACKEND,
- convert_state_dict_to_diffusers,
- convert_state_dict_to_peft,
- convert_unet_state_dict_to_peft,
deprecate,
- get_adapter_name,
- get_peft_kwargs,
+ get_submodule_by_name,
is_peft_available,
is_peft_version,
is_torch_version,
is_transformers_available,
is_transformers_version,
logging,
- scale_lora_layers,
)
-from .lora_base import LoraBaseMixin
+from .lora_base import ( # noqa
+ LORA_WEIGHT_NAME,
+ LORA_WEIGHT_NAME_SAFE,
+ LoraBaseMixin,
+ _fetch_state_dict,
+ _load_lora_into_text_encoder,
+)
from .lora_conversion_utils import (
+ _convert_bfl_flux_control_lora_to_diffusers,
+ _convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
+ _convert_non_diffusers_lumina2_lora_to_diffusers,
+ _convert_non_diffusers_wan_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
)
@@ -53,17 +59,13 @@
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
-if is_transformers_available():
- from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
-
logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
-LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
-LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
@@ -222,7 +224,7 @@ def lora_state_dict(
"framework": "pytorch",
}
- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -282,7 +284,9 @@ def load_lora_into_unet(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -295,18 +299,15 @@ def load_lora_into_unet(
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
- keys = list(state_dict.keys())
- only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
- if not only_text_encoder:
- # Load the layers corresponding to UNet.
- logger.info(f"Loading {cls.unet_name}.")
- unet.load_attn_procs(
- state_dict,
- network_alphas=network_alphas,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ logger.info(f"Loading {cls.unet_name}.")
+ unet.load_lora_adapter(
+ state_dict,
+ prefix=cls.unet_name,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
def load_lora_into_text_encoder(
@@ -341,109 +342,21 @@ def load_lora_into_text_encoder(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- peft_kwargs = {}
- if low_cpu_mem_usage:
- if not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
- if not is_transformers_version(">", "4.45.2"):
- # Note from sayakpaul: It's not in `transformers` stable yet.
- # https://github.com/huggingface/transformers/pull/33725/
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
- )
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- from peft import LoraConfig
-
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
- # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
- # their prefixes.
- keys = list(state_dict.keys())
- prefix = cls.text_encoder_name if prefix is None else prefix
-
- # Safe prefix to check with.
- if any(cls.text_encoder_name in key for key in keys):
- # Load the layers corresponding to text encoder and make necessary adjustments.
- text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
- text_encoder_lora_state_dict = {
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
- }
-
- if len(text_encoder_lora_state_dict) > 0:
- logger.info(f"Loading {prefix}.")
- rank = {}
- text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-
- # convert state dict
- text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
- for name, _ in text_encoder_attn_modules(text_encoder):
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- for name, _ in text_encoder_mlp_modules(text_encoder):
- for module in ("fc1", "fc2"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- if network_alphas is not None:
- alpha_keys = [
- k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
- ]
- network_alphas = {
- k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
- }
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"]:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<", "0.9.0"):
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(text_encoder)
-
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- # inject LoRA layers and load the state dict
- # in transformers we automatically check whether the adapter name is already in use or not
- text_encoder.load_adapter(
- adapter_name=adapter_name,
- adapter_state_dict=text_encoder_lora_state_dict,
- peft_config=lora_config,
- **peft_kwargs,
- )
-
- # scale LoRA layers with `lora_scale`
- scale_lora_layers(text_encoder, weight=lora_scale)
-
- text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ _load_lora_into_text_encoder(
+ state_dict=state_dict,
+ network_alphas=network_alphas,
+ lora_scale=lora_scale,
+ text_encoder=text_encoder,
+ prefix=prefix,
+ text_encoder_name=cls.text_encoder_name,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
def save_lora_weights(
@@ -539,7 +452,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
)
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
@@ -560,7 +477,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
- super().unfuse_lora(components=components)
+ super().unfuse_lora(components=components, **kwargs)
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
@@ -601,7 +518,9 @@ def load_lora_weights(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -641,31 +560,26 @@ def load_lora_weights(
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
- text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
- if len(text_encoder_state_dict) > 0:
- self.load_lora_into_text_encoder(
- text_encoder_state_dict,
- network_alphas=network_alphas,
- text_encoder=self.text_encoder,
- prefix="text_encoder",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
-
- text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
- if len(text_encoder_2_state_dict) > 0:
- self.load_lora_into_text_encoder(
- text_encoder_2_state_dict,
- network_alphas=network_alphas,
- text_encoder=self.text_encoder_2,
- prefix="text_encoder_2",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix=self.text_encoder_name,
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix=f"{self.text_encoder_name}_2",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
@validate_hf_hub_args
@@ -744,7 +658,7 @@ def lora_state_dict(
"framework": "pytorch",
}
- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -805,7 +719,9 @@ def load_lora_into_unet(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -818,18 +734,15 @@ def load_lora_into_unet(
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
- keys = list(state_dict.keys())
- only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
- if not only_text_encoder:
- # Load the layers corresponding to UNet.
- logger.info(f"Loading {cls.unet_name}.")
- unet.load_attn_procs(
- state_dict,
- network_alphas=network_alphas,
- adapter_name=adapter_name,
- _pipeline=_pipeline,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ logger.info(f"Loading {cls.unet_name}.")
+ unet.load_lora_adapter(
+ state_dict,
+ prefix=cls.unet_name,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -865,109 +778,21 @@ def load_lora_into_text_encoder(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- peft_kwargs = {}
- if low_cpu_mem_usage:
- if not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
- if not is_transformers_version(">", "4.45.2"):
- # Note from sayakpaul: It's not in `transformers` stable yet.
- # https://github.com/huggingface/transformers/pull/33725/
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
- )
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- from peft import LoraConfig
-
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
- # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
- # their prefixes.
- keys = list(state_dict.keys())
- prefix = cls.text_encoder_name if prefix is None else prefix
-
- # Safe prefix to check with.
- if any(cls.text_encoder_name in key for key in keys):
- # Load the layers corresponding to text encoder and make necessary adjustments.
- text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
- text_encoder_lora_state_dict = {
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
- }
-
- if len(text_encoder_lora_state_dict) > 0:
- logger.info(f"Loading {prefix}.")
- rank = {}
- text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-
- # convert state dict
- text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
- for name, _ in text_encoder_attn_modules(text_encoder):
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- for name, _ in text_encoder_mlp_modules(text_encoder):
- for module in ("fc1", "fc2"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- if network_alphas is not None:
- alpha_keys = [
- k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
- ]
- network_alphas = {
- k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
- }
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"]:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<", "0.9.0"):
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(text_encoder)
-
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- # inject LoRA layers and load the state dict
- # in transformers we automatically check whether the adapter name is already in use or not
- text_encoder.load_adapter(
- adapter_name=adapter_name,
- adapter_state_dict=text_encoder_lora_state_dict,
- peft_config=lora_config,
- **peft_kwargs,
- )
-
- # scale LoRA layers with `lora_scale`
- scale_lora_layers(text_encoder, weight=lora_scale)
-
- text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ _load_lora_into_text_encoder(
+ state_dict=state_dict,
+ network_alphas=network_alphas,
+ lora_scale=lora_scale,
+ text_encoder=text_encoder,
+ prefix=prefix,
+ text_encoder_name=cls.text_encoder_name,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
def save_lora_weights(
@@ -1010,11 +835,11 @@ def save_lora_weights(
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
- "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
)
if unet_lora_layers:
- state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
+ state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
@@ -1071,7 +896,11 @@ def fuse_lora(
```
"""
super().fuse_lora(
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
)
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
@@ -1092,7 +921,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_enc
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
- super().unfuse_lora(components=components)
+ super().unfuse_lora(components=components, **kwargs)
class SD3LoraLoaderMixin(LoraBaseMixin):
@@ -1182,7 +1011,7 @@ def lora_state_dict(
"framework": "pytorch",
}
- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -1226,7 +1055,9 @@ def load_lora_weights(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -1257,32 +1088,26 @@ def load_lora_weights(
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
-
- text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
- if len(text_encoder_state_dict) > 0:
- self.load_lora_into_text_encoder(
- text_encoder_state_dict,
- network_alphas=None,
- text_encoder=self.text_encoder,
- prefix="text_encoder",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
-
- text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
- if len(text_encoder_2_state_dict) > 0:
- self.load_lora_into_text_encoder(
- text_encoder_2_state_dict,
- network_alphas=None,
- text_encoder=self.text_encoder_2,
- prefix="text_encoder_2",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=None,
+ text_encoder=self.text_encoder,
+ prefix=self.text_encoder_name,
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=None,
+ text_encoder=self.text_encoder_2,
+ prefix=f"{self.text_encoder_name}_2",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
def load_lora_into_transformer(
@@ -1301,94 +1126,24 @@ def load_lora_into_transformer(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
- keys = list(state_dict.keys())
-
- transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
- state_dict = {
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
- }
-
- if len(state_dict.keys()) > 0:
- # check with first key if is not in peft format
- first_key = next(iter(state_dict.keys()))
- if "lora_A" not in first_key:
- state_dict = convert_unet_state_dict_to_peft(state_dict)
-
- if adapter_name in getattr(transformer, "peft_config", {}):
- raise ValueError(
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
- )
-
- rank = {}
- for key, val in state_dict.items():
- if "lora_B" in key:
- rank[key] = val.shape[1]
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(transformer)
-
- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
- # otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- peft_kwargs = {}
- if is_peft_version(">=", "0.13.1"):
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
-
- warn_msg = ""
- if incompatible_keys is not None:
- # Check only for unexpected keys.
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
- if lora_unexpected_keys:
- warn_msg = (
- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
- f" {', '.join(lora_unexpected_keys)}. "
- )
-
- # Filter missing keys specific to the current adapter.
- missing_keys = getattr(incompatible_keys, "missing_keys", None)
- if missing_keys:
- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
- if lora_missing_keys:
- warn_msg += (
- f"Loading adapter weights from state_dict led to missing keys in the model:"
- f" {', '.join(lora_missing_keys)}."
- )
-
- if warn_msg:
- logger.warning(warn_msg)
-
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1424,115 +1179,28 @@ def load_lora_into_text_encoder(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- peft_kwargs = {}
- if low_cpu_mem_usage:
- if not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
- if not is_transformers_version(">", "4.45.2"):
- # Note from sayakpaul: It's not in `transformers` stable yet.
- # https://github.com/huggingface/transformers/pull/33725/
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
- )
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- from peft import LoraConfig
-
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
- # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
- # their prefixes.
- keys = list(state_dict.keys())
- prefix = cls.text_encoder_name if prefix is None else prefix
-
- # Safe prefix to check with.
- if any(cls.text_encoder_name in key for key in keys):
- # Load the layers corresponding to text encoder and make necessary adjustments.
- text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
- text_encoder_lora_state_dict = {
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
- }
-
- if len(text_encoder_lora_state_dict) > 0:
- logger.info(f"Loading {prefix}.")
- rank = {}
- text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-
- # convert state dict
- text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
- for name, _ in text_encoder_attn_modules(text_encoder):
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- for name, _ in text_encoder_mlp_modules(text_encoder):
- for module in ("fc1", "fc2"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- if network_alphas is not None:
- alpha_keys = [
- k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
- ]
- network_alphas = {
- k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
- }
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"]:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<", "0.9.0"):
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(text_encoder)
-
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- # inject LoRA layers and load the state dict
- # in transformers we automatically check whether the adapter name is already in use or not
- text_encoder.load_adapter(
- adapter_name=adapter_name,
- adapter_state_dict=text_encoder_lora_state_dict,
- peft_config=lora_config,
- **peft_kwargs,
- )
-
- # scale LoRA layers with `lora_scale`
- scale_lora_layers(text_encoder, weight=lora_scale)
-
- text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ _load_lora_into_text_encoder(
+ state_dict=state_dict,
+ network_alphas=network_alphas,
+ lora_scale=lora_scale,
+ text_encoder=text_encoder,
+ prefix=prefix,
+ text_encoder_name=cls.text_encoder_name,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
- transformer_lora_layers: Dict[str, torch.nn.Module] = None,
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
@@ -1581,7 +1249,6 @@ def save_lora_weights(
if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
- # Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
@@ -1591,6 +1258,7 @@ def save_lora_weights(
safe_serialization=safe_serialization,
)
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
@@ -1631,9 +1299,14 @@ def fuse_lora(
```
"""
super().fuse_lora(
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
)
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
r"""
Reverses the effect of
@@ -1647,12 +1320,12 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
- super().unfuse_lora(components=components)
+ super().unfuse_lora(components=components, **kwargs)
class FluxLoraLoaderMixin(LoraBaseMixin):
@@ -1666,6 +1339,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
_lora_loadable_modules = ["transformer", "text_encoder"]
transformer_name = TRANSFORMER_NAME
text_encoder_name = TEXT_ENCODER_NAME
+ _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
@classmethod
@validate_hf_hub_args
@@ -1742,7 +1416,7 @@ def lora_state_dict(
"framework": "pytorch",
}
- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -1775,6 +1449,11 @@ def lora_state_dict(
# xlabs doesn't use `alpha`.
return (state_dict, None) if return_alphas else state_dict
+ is_bfl_control = any("query_norm.scale" in k for k in state_dict)
+ if is_bfl_control:
+ state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
+ return (state_dict, None) if return_alphas else state_dict
+
# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
@@ -1819,7 +1498,9 @@ def load_lora_weights(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1839,32 +1520,75 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
)
- is_correct_format = all("lora" in key for key in state_dict.keys())
- if not is_correct_format:
+ has_lora_keys = any("lora" in key for key in state_dict.keys())
+
+ # Flux Control LoRAs also have norm keys
+ has_norm_keys = any(
+ norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
+ )
+
+ if not (has_lora_keys or has_norm_keys):
raise ValueError("Invalid LoRA checkpoint.")
+ transformer_lora_state_dict = {
+ k: state_dict.get(k)
+ for k in list(state_dict.keys())
+ if k.startswith(f"{self.transformer_name}.") and "lora" in k
+ }
+ transformer_norm_state_dict = {
+ k: state_dict.pop(k)
+ for k in list(state_dict.keys())
+ if k.startswith(f"{self.transformer_name}.")
+ and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
+ }
+
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ has_param_with_expanded_shape = False
+ if len(transformer_lora_state_dict) > 0:
+ has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
+ transformer, transformer_lora_state_dict, transformer_norm_state_dict
+ )
+
+ if has_param_with_expanded_shape:
+ logger.info(
+ "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
+ "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
+ "To get a comprehensive list of parameter names that were modified, enable debug logging."
+ )
+ if len(transformer_lora_state_dict) > 0:
+ transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+ transformer=transformer, lora_state_dict=transformer_lora_state_dict
+ )
+ for k in transformer_lora_state_dict:
+ state_dict.update({k: transformer_lora_state_dict[k]})
+
self.load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ transformer=transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
- text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
- if len(text_encoder_state_dict) > 0:
- self.load_lora_into_text_encoder(
- text_encoder_state_dict,
- network_alphas=network_alphas,
- text_encoder=self.text_encoder,
- prefix="text_encoder",
- lora_scale=self.lora_scale,
- adapter_name=adapter_name,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ if len(transformer_norm_state_dict) > 0:
+ transformer._transformer_norm_layers = self._load_norm_into_transformer(
+ transformer_norm_state_dict,
+ transformer=transformer,
+ discard_original_layers=False,
)
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix=self.text_encoder_name,
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
@classmethod
def load_lora_into_transformer(
cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
@@ -1881,104 +1605,83 @@ def load_lora_into_transformer(
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- transformer (`SD3Transformer2DModel`):
+ transformer (`FluxTransformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
- keys = list(state_dict.keys())
-
- transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
- state_dict = {
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
- }
-
- if len(state_dict.keys()) > 0:
- # check with first key if is not in peft format
- first_key = next(iter(state_dict.keys()))
- if "lora_A" not in first_key:
- state_dict = convert_unet_state_dict_to_peft(state_dict)
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
- if adapter_name in getattr(transformer, "peft_config", {}):
- raise ValueError(
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
- )
+ @classmethod
+ def _load_norm_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ prefix=None,
+ discard_original_layers=False,
+ ) -> Dict[str, torch.Tensor]:
+ # Remove prefix if present
+ prefix = prefix or cls.transformer_name
+ for key in list(state_dict.keys()):
+ if key.split(".")[0] == prefix:
+ state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+ # Find invalid keys
+ transformer_state_dict = transformer.state_dict()
+ transformer_keys = set(transformer_state_dict.keys())
+ state_dict_keys = set(state_dict.keys())
+ extra_keys = list(state_dict_keys - transformer_keys)
+
+ if extra_keys:
+ logger.warning(
+ f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
+ )
- rank = {}
- for key, val in state_dict.items():
- if "lora_B" in key:
- rank[key] = val.shape[1]
+ for key in extra_keys:
+ state_dict.pop(key)
- if network_alphas is not None and len(network_alphas) >= 1:
- prefix = cls.transformer_name
- alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
- network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+ # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
+ overwritten_layers_state_dict = {}
+ if not discard_original_layers:
+ for key in state_dict.keys():
+ overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()
- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(transformer)
-
- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
- # otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- peft_kwargs = {}
- if is_peft_version(">=", "0.13.1"):
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
-
- warn_msg = ""
- if incompatible_keys is not None:
- # Check only for unexpected keys.
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
- if lora_unexpected_keys:
- warn_msg = (
- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
- f" {', '.join(lora_unexpected_keys)}. "
- )
+ logger.info(
+ "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
+ 'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
+ "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. "
+ "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
+ )
- # Filter missing keys specific to the current adapter.
- missing_keys = getattr(incompatible_keys, "missing_keys", None)
- if missing_keys:
- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
- if lora_missing_keys:
- warn_msg += (
- f"Loading adapter weights from state_dict led to missing keys in the model:"
- f" {', '.join(lora_missing_keys)}."
- )
+ # We can't load with strict=True because the current state_dict does not contain all the transformer keys
+ incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if warn_msg:
- logger.warning(warn_msg)
+ # We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
+ if unexpected_keys:
+ if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
+ raise ValueError(
+ f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
+ )
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ return overwritten_layers_state_dict
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2014,109 +1717,21 @@ def load_lora_into_text_encoder(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- peft_kwargs = {}
- if low_cpu_mem_usage:
- if not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
- if not is_transformers_version(">", "4.45.2"):
- # Note from sayakpaul: It's not in `transformers` stable yet.
- # https://github.com/huggingface/transformers/pull/33725/
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
- )
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- from peft import LoraConfig
-
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
- # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
- # their prefixes.
- keys = list(state_dict.keys())
- prefix = cls.text_encoder_name if prefix is None else prefix
-
- # Safe prefix to check with.
- if any(cls.text_encoder_name in key for key in keys):
- # Load the layers corresponding to text encoder and make necessary adjustments.
- text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
- text_encoder_lora_state_dict = {
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
- }
-
- if len(text_encoder_lora_state_dict) > 0:
- logger.info(f"Loading {prefix}.")
- rank = {}
- text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-
- # convert state dict
- text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
- for name, _ in text_encoder_attn_modules(text_encoder):
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- for name, _ in text_encoder_mlp_modules(text_encoder):
- for module in ("fc1", "fc2"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- if network_alphas is not None:
- alpha_keys = [
- k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
- ]
- network_alphas = {
- k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
- }
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"]:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<", "0.9.0"):
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(text_encoder)
-
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- # inject LoRA layers and load the state dict
- # in transformers we automatically check whether the adapter name is already in use or not
- text_encoder.load_adapter(
- adapter_name=adapter_name,
- adapter_state_dict=text_encoder_lora_state_dict,
- peft_config=lora_config,
- **peft_kwargs,
- )
-
- # scale LoRA layers with `lora_scale`
- scale_lora_layers(text_encoder, weight=lora_scale)
-
- text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ _load_lora_into_text_encoder(
+ state_dict=state_dict,
+ network_alphas=network_alphas,
+ lora_scale=lora_scale,
+ text_encoder=text_encoder,
+ prefix=prefix,
+ text_encoder_name=cls.text_encoder_name,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
@@ -2173,10 +1788,9 @@ def save_lora_weights(
safe_serialization=safe_serialization,
)
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
- components: List[str] = ["transformer", "text_encoder"],
+ components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -2213,8 +1827,25 @@ def fuse_lora(
pipeline.fuse_lora(lora_scale=0.7)
```
"""
+
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ if (
+ hasattr(transformer, "_transformer_norm_layers")
+ and isinstance(transformer._transformer_norm_layers, dict)
+ and len(transformer._transformer_norm_layers.keys()) > 0
+ ):
+ logger.info(
+ "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer "
+ "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
+ "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
+ )
+
super().fuse_lora(
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
)
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
@@ -2231,10 +1862,276 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
- super().unfuse_lora(components=components)
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+ transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+ super().unfuse_lora(components=components, **kwargs)
-# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
+ # We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
+ def unload_lora_weights(self, reset_to_overwritten_params=False):
+ """
+ Unloads the LoRA parameters.
+
+ Args:
+ reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
+ to their original params. Refer to the [Flux
+ documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+ >>> pipeline.unload_lora_weights()
+ >>> ...
+ ```
+ """
+ super().unload_lora_weights()
+
+ transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+ if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+ transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+ transformer._transformer_norm_layers = None
+
+ if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
+ overwritten_params = transformer._overwritten_params
+ module_names = set()
+
+ for param_name in overwritten_params:
+ if param_name.endswith(".weight"):
+ module_names.add(param_name.replace(".weight", ""))
+
+ for name, module in transformer.named_modules():
+ if isinstance(module, torch.nn.Linear) and name in module_names:
+ module_weight = module.weight.data
+ module_bias = module.bias.data if module.bias is not None else None
+ bias = module_bias is not None
+
+ parent_module_name, _, current_module_name = name.rpartition(".")
+ parent_module = transformer.get_submodule(parent_module_name)
+
+ current_param_weight = overwritten_params[f"{name}.weight"]
+ in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
+ with torch.device("meta"):
+ original_module = torch.nn.Linear(
+ in_features,
+ out_features,
+ bias=bias,
+ dtype=module_weight.dtype,
+ )
+
+ tmp_state_dict = {"weight": current_param_weight}
+ if module_bias is not None:
+ tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
+ original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
+ setattr(parent_module, current_module_name, original_module)
+
+ del tmp_state_dict
+
+ if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+ attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+ new_value = int(current_param_weight.shape[1])
+ old_value = getattr(transformer.config, attribute_name)
+ setattr(transformer.config, attribute_name, new_value)
+ logger.info(
+ f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+ )
+
+ @classmethod
+ def _maybe_expand_transformer_param_shape_or_error_(
+ cls,
+ transformer: torch.nn.Module,
+ lora_state_dict=None,
+ norm_state_dict=None,
+ prefix=None,
+ ) -> bool:
+ """
+ Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
+ generalizes things a bit so that any parameter that needs expansion receives appropriate treatement.
+ """
+ state_dict = {}
+ if lora_state_dict is not None:
+ state_dict.update(lora_state_dict)
+ if norm_state_dict is not None:
+ state_dict.update(norm_state_dict)
+
+ # Remove prefix if present
+ prefix = prefix or cls.transformer_name
+ for key in list(state_dict.keys()):
+ if key.split(".")[0] == prefix:
+ state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+ # Expand transformer parameter shapes if they don't match lora
+ has_param_with_shape_update = False
+ overwritten_params = {}
+
+ is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+ for name, module in transformer.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ module_weight = module.weight.data
+ module_bias = module.bias.data if module.bias is not None else None
+ bias = module_bias is not None
+
+ lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+ lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+ lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+ if lora_A_weight_name not in state_dict:
+ continue
+
+ in_features = state_dict[lora_A_weight_name].shape[1]
+ out_features = state_dict[lora_B_weight_name].shape[0]
+
+ # Model maybe loaded with different quantization schemes which may flatten the params.
+ # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
+ # preserve weight shape.
+ module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
+ # This means there's no need for an expansion in the params, so we simply skip.
+ if tuple(module_weight_shape) == (out_features, in_features):
+ continue
+
+ # TODO (sayakpaul): We still need to consider if the module we're expanding is
+ # quantized and handle it accordingly if that is the case.
+ module_out_features, module_in_features = module_weight.shape
+ debug_message = ""
+ if in_features > module_in_features:
+ debug_message += (
+ f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+ f"checkpoint contains higher number of features than expected. The number of input_features will be "
+ f"expanded from {module_in_features} to {in_features}"
+ )
+ if out_features > module_out_features:
+ debug_message += (
+ ", and the number of output features will be "
+ f"expanded from {module_out_features} to {out_features}."
+ )
+ else:
+ debug_message += "."
+ if debug_message:
+ logger.debug(debug_message)
+
+ if out_features > module_out_features or in_features > module_in_features:
+ has_param_with_shape_update = True
+ parent_module_name, _, current_module_name = name.rpartition(".")
+ parent_module = transformer.get_submodule(parent_module_name)
+
+ with torch.device("meta"):
+ expanded_module = torch.nn.Linear(
+ in_features, out_features, bias=bias, dtype=module_weight.dtype
+ )
+ # Only weights are expanded and biases are not. This is because only the input dimensions
+ # are changed while the output dimensions remain the same. The shape of the weight tensor
+ # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
+ # explains the reason why only weights are expanded.
+ new_weight = torch.zeros_like(
+ expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+ )
+ slices = tuple(slice(0, dim) for dim in module_weight.shape)
+ new_weight[slices] = module_weight
+ tmp_state_dict = {"weight": new_weight}
+ if module_bias is not None:
+ tmp_state_dict["bias"] = module_bias
+ expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
+
+ setattr(parent_module, current_module_name, expanded_module)
+
+ del tmp_state_dict
+
+ if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+ attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+ new_value = int(expanded_module.weight.data.shape[1])
+ old_value = getattr(transformer.config, attribute_name)
+ setattr(transformer.config, attribute_name, new_value)
+ logger.info(
+ f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+ )
+
+ # For `unload_lora_weights()`.
+ # TODO: this could lead to more memory overhead if the number of overwritten params
+ # are large. Should be revisited later and tackled through a `discard_original_layers` arg.
+ overwritten_params[f"{current_module_name}.weight"] = module_weight
+ if module_bias is not None:
+ overwritten_params[f"{current_module_name}.bias"] = module_bias
+
+ if len(overwritten_params) > 0:
+ transformer._overwritten_params = overwritten_params
+
+ return has_param_with_shape_update
+
+ @classmethod
+ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+ expanded_module_names = set()
+ transformer_state_dict = transformer.state_dict()
+ prefix = f"{cls.transformer_name}."
+
+ lora_module_names = [
+ key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+ ]
+ lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+ lora_module_names = sorted(set(lora_module_names))
+ transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+ unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+ if unexpected_modules:
+ logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
+
+ is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+ for k in lora_module_names:
+ if k in unexpected_modules:
+ continue
+
+ base_param_name = (
+ f"{k.replace(prefix, '')}.base_layer.weight"
+ if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+ else f"{k.replace(prefix, '')}.weight"
+ )
+ base_weight_param = transformer_state_dict[base_param_name]
+ lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
+
+ # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+ base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
+ if base_module_shape[1] > lora_A_param.shape[1]:
+ shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+ expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+ expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
+ lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
+ expanded_module_names.add(k)
+ elif base_module_shape[1] < lora_A_param.shape[1]:
+ raise NotImplementedError(
+ f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
+ )
+
+ if expanded_module_names:
+ logger.info(
+ f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+ )
+
+ return lora_state_dict
+
+ @staticmethod
+ def _calculate_module_shape(
+ model: "torch.nn.Module",
+ base_module: "torch.nn.Linear" = None,
+ base_weight_param_name: str = None,
+ ) -> "torch.Size":
+ def _get_weight_shape(weight: torch.Tensor):
+ return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+ if base_module is not None:
+ return _get_weight_shape(base_module.weight)
+ elif base_weight_param_name is not None:
+ if not base_weight_param_name.endswith(".weight"):
+ raise ValueError(
+ f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+ )
+ module_path = base_weight_param_name.rsplit(".weight", 1)[0]
+ submodule = get_submodule_by_name(model, module_path)
+ return _get_weight_shape(submodule.weight)
+
+ raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+
+
+# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
_lora_loadable_modules = ["transformer", "text_encoder"]
@@ -2242,7 +2139,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
text_encoder_name = TEXT_ENCODER_NAME
@classmethod
- def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
+ # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2255,93 +2155,29 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- unet (`UNet2DConditionModel`):
- The UNet model to load the LoRA layers into.
+ transformer (`UVit2DModel`):
+ The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
- keys = list(state_dict.keys())
-
- transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
- state_dict = {
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
- }
-
- if network_alphas is not None:
- alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
- network_alphas = {
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
- }
-
- if len(state_dict.keys()) > 0:
- if adapter_name in getattr(transformer, "peft_config", {}):
- raise ValueError(
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
- )
-
- rank = {}
- for key, val in state_dict.items():
- if "lora_B" in key:
- rank[key] = val.shape[1]
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(transformer)
-
- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
- # otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
- warn_msg = ""
- if incompatible_keys is not None:
- # Check only for unexpected keys.
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
- if lora_unexpected_keys:
- warn_msg = (
- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
- f" {', '.join(lora_unexpected_keys)}. "
- )
-
- # Filter missing keys specific to the current adapter.
- missing_keys = getattr(incompatible_keys, "missing_keys", None)
- if missing_keys:
- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
- if lora_missing_keys:
- warn_msg += (
- f"Loading adapter weights from state_dict led to missing keys in the model:"
- f" {', '.join(lora_missing_keys)}."
- )
-
- if warn_msg:
- logger.warning(warn_msg)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=network_alphas,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2377,109 +2213,21 @@ def load_lora_into_text_encoder(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
- if not USE_PEFT_BACKEND:
- raise ValueError("PEFT backend is required for this method.")
-
- peft_kwargs = {}
- if low_cpu_mem_usage:
- if not is_peft_version(">=", "0.13.1"):
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
- )
- if not is_transformers_version(">", "4.45.2"):
- # Note from sayakpaul: It's not in `transformers` stable yet.
- # https://github.com/huggingface/transformers/pull/33725/
- raise ValueError(
- "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
- )
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- from peft import LoraConfig
-
- # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
- # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
- # their prefixes.
- keys = list(state_dict.keys())
- prefix = cls.text_encoder_name if prefix is None else prefix
-
- # Safe prefix to check with.
- if any(cls.text_encoder_name in key for key in keys):
- # Load the layers corresponding to text encoder and make necessary adjustments.
- text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
- text_encoder_lora_state_dict = {
- k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
- }
-
- if len(text_encoder_lora_state_dict) > 0:
- logger.info(f"Loading {prefix}.")
- rank = {}
- text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-
- # convert state dict
- text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
- for name, _ in text_encoder_attn_modules(text_encoder):
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- for name, _ in text_encoder_mlp_modules(text_encoder):
- for module in ("fc1", "fc2"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in text_encoder_lora_state_dict:
- continue
- rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
- if network_alphas is not None:
- alpha_keys = [
- k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
- ]
- network_alphas = {
- k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
- }
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"]:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<", "0.9.0"):
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(text_encoder)
-
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- # inject LoRA layers and load the state dict
- # in transformers we automatically check whether the adapter name is already in use or not
- text_encoder.load_adapter(
- adapter_name=adapter_name,
- adapter_state_dict=text_encoder_lora_state_dict,
- peft_config=lora_config,
- **peft_kwargs,
- )
-
- # scale LoRA layers with `lora_scale`
- scale_lora_layers(text_encoder, weight=lora_scale)
-
- text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ _load_lora_into_text_encoder(
+ state_dict=state_dict,
+ network_alphas=network_alphas,
+ lora_scale=lora_scale,
+ text_encoder=text_encoder,
+ prefix=prefix,
+ text_encoder_name=cls.text_encoder_name,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
def save_lora_weights(
@@ -2538,7 +2286,7 @@ def save_lora_weights(
class CogVideoXLoraLoaderMixin(LoraBaseMixin):
r"""
- Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX`].
+ Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
"""
_lora_loadable_modules = ["transformer"]
@@ -2619,7 +2367,7 @@ def lora_state_dict(
"framework": "pytorch",
}
- state_dict = cls._fetch_state_dict(
+ state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -2658,7 +2406,9 @@ def load_lora_weights(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -2691,7 +2441,7 @@ def load_lora_weights(
)
@classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
@@ -2703,99 +2453,29 @@ def load_lora_into_transformer(
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
- transformer (`SD3Transformer2DModel`):
+ transformer (`CogVideoXTransformer3DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
- keys = list(state_dict.keys())
-
- transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
- state_dict = {
- k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
- }
-
- if len(state_dict.keys()) > 0:
- # check with first key if is not in peft format
- first_key = next(iter(state_dict.keys()))
- if "lora_A" not in first_key:
- state_dict = convert_unet_state_dict_to_peft(state_dict)
-
- if adapter_name in getattr(transformer, "peft_config", {}):
- raise ValueError(
- f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
- )
-
- rank = {}
- for key, val in state_dict.items():
- if "lora_B" in key:
- rank[key] = val.shape[1]
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- lora_config_kwargs.pop("use_dora")
- lora_config = LoraConfig(**lora_config_kwargs)
-
- # adapter_name
- if adapter_name is None:
- adapter_name = get_adapter_name(transformer)
-
- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
- # otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
- peft_kwargs = {}
- if is_peft_version(">=", "0.13.1"):
- peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
- inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
- incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
-
- warn_msg = ""
- if incompatible_keys is not None:
- # Check only for unexpected keys.
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
- if lora_unexpected_keys:
- warn_msg = (
- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
- f" {', '.join(lora_unexpected_keys)}. "
- )
-
- # Filter missing keys specific to the current adapter.
- missing_keys = getattr(incompatible_keys, "missing_keys", None)
- if missing_keys:
- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
- if lora_missing_keys:
- warn_msg += (
- f"Loading adapter weights from state_dict led to missing keys in the model:"
- f" {', '.join(lora_missing_keys)}."
- )
-
- if warn_msg:
- logger.warning(warn_msg)
-
- # Offload back.
- if is_model_cpu_offload:
- _pipeline.enable_model_cpu_offload()
- elif is_sequential_cpu_offload:
- _pipeline.enable_sequential_cpu_offload()
- # Unsafe code />
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
@classmethod
# Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
@@ -2845,10 +2525,9 @@ def save_lora_weights(
safe_serialization=safe_serialization,
)
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
- components: List[str] = ["transformer", "text_encoder"],
+ components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -2886,11 +2565,2215 @@ def fuse_lora(
```
"""
super().fuse_lora(
- components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
)
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
- def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class Mochi1LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`MochiTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class LTXVideoLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`LTXVideoTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class SanaLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`SanaTransformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading original format HunyuanVideo LoRA checkpoints.
+
+ This function is experimental and might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
+ if is_original_hunyuan_video:
+ state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`HunyuanVideoTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class Lumina2LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ # conversion.
+ non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
+ if non_diffusers:
+ state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`Lumina2Transformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class WanLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+ if any(k.startswith("diffusion_model.") for k in state_dict):
+ state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ @classmethod
+ def _maybe_expand_t2v_lora_for_i2v(
+ cls,
+ transformer: torch.nn.Module,
+ state_dict,
+ ):
+ if transformer.config.image_dim is None:
+ return state_dict
+
+ if any(k.startswith("transformer.blocks.") for k in state_dict):
+ num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
+ is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+
+ if is_i2v_lora:
+ return state_dict
+
+ for i in range(num_blocks):
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
+ )
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
+ )
+
+ return state_dict
+
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
+ state_dict = self._maybe_expand_t2v_lora_for_i2v(
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ state_dict=state_dict,
+ )
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`WanTransformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class CogView4LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
+ def load_lora_into_transformer(
+ cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`CogView4Transformer2DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ if transformer_lora_layers:
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -2904,11 +4787,8 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
"""
- super().unfuse_lora(components=components)
+ super().unfuse_lora(components=components, **kwargs)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index d1c6721512fa..8b52cf63456c 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,30 +13,106 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
+import os
from functools import partial
+from pathlib import Path
from typing import Dict, List, Optional, Union
+import safetensors
+import torch
+
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
check_peft_version,
+ convert_unet_state_dict_to_peft,
delete_adapter_layers,
+ get_adapter_name,
+ get_peft_kwargs,
is_peft_available,
+ is_peft_version,
+ logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
+from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
from .unet_loader_utils import _maybe_expand_lora_scales
+logger = logging.get_logger(__name__)
+
_SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
"FluxTransformer2DModel": lambda model_cls, weights: weights,
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
+ "ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
+ "MochiTransformer3DModel": lambda model_cls, weights: weights,
+ "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
+ "LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
+ "SanaTransformer2DModel": lambda model_cls, weights: weights,
+ "Lumina2Transformer2DModel": lambda model_cls, weights: weights,
+ "WanTransformer3DModel": lambda model_cls, weights: weights,
+ "CogView4Transformer2DModel": lambda model_cls, weights: weights,
}
+def _maybe_adjust_config(config):
+ """
+ We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
+ (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
+ method removes the ambiguity by following what is described here:
+ https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
+ """
+ # Track keys that have been explicitly removed to prevent re-adding them.
+ deleted_keys = set()
+
+ rank_pattern = config["rank_pattern"].copy()
+ target_modules = config["target_modules"]
+ original_r = config["r"]
+
+ for key in list(rank_pattern.keys()):
+ key_rank = rank_pattern[key]
+
+ # try to detect ambiguity
+ # `target_modules` can also be a str, in which case this loop would loop
+ # over the chars of the str. The technically correct way to match LoRA keys
+ # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
+ # But this cuts it for now.
+ exact_matches = [mod for mod in target_modules if mod == key]
+ substring_matches = [mod for mod in target_modules if key in mod and mod != key]
+ ambiguous_key = key
+
+ if exact_matches and substring_matches:
+ # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
+ config["r"] = key_rank
+ # remove the ambiguous key from `rank_pattern` and record it as deleted
+ del config["rank_pattern"][key]
+ deleted_keys.add(key)
+ # For substring matches, add them with the original rank only if they haven't been assigned already
+ for mod in substring_matches:
+ if mod not in config["rank_pattern"] and mod not in deleted_keys:
+ config["rank_pattern"][mod] = original_r
+
+ # Update the rest of the target modules with the original rank if not already set and not deleted
+ for mod in target_modules:
+ if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
+ config["rank_pattern"][mod] = original_r
+
+ # Handle alphas to deal with cases like:
+ # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
+ has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
+ if has_different_ranks:
+ config["lora_alpha"] = config["r"]
+ alpha_pattern = {}
+ for module_name, rank in config["rank_pattern"].items():
+ alpha_pattern[module_name] = rank
+ config["alpha_pattern"] = alpha_pattern
+
+ return config
+
+
class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -53,6 +129,305 @@ class PeftAdapterMixin:
_hf_peft_config_loaded = False
+ @classmethod
+ # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
+ def _optionally_disable_offloading(cls, _pipeline):
+ """
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+ Args:
+ _pipeline (`DiffusionPipeline`):
+ The pipeline to disable offloading for.
+
+ Returns:
+ tuple:
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+ """
+ return _func_optionally_disable_offloading(_pipeline=_pipeline)
+
+ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
+ r"""
+ Loads a LoRA adapter into the underlying model.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ prefix (`str`, *optional*): Prefix to filter the state dict.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ network_alphas (`Dict[str, float]`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ """
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ adapter_name = kwargs.pop("adapter_name", None)
+ network_alphas = kwargs.pop("network_alphas", None)
+ _pipeline = kwargs.pop("_pipeline", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+ allow_pickle = False
+
+ if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ state_dict = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+ if network_alphas is not None and prefix is None:
+ raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
+
+ if prefix is not None:
+ state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+
+ if len(state_dict) > 0:
+ if adapter_name in getattr(self, "peft_config", {}):
+ raise ValueError(
+ f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
+ )
+
+ # check with first key if is not in peft format
+ first_key = next(iter(state_dict.keys()))
+ if "lora_A" not in first_key:
+ state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+ rank = {}
+ for key, val in state_dict.items():
+ # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
+ # Bias layers in LoRA only have a single dimension
+ if "lora_B" in key and val.ndim > 1:
+ # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
+ rank[key] = val.shape[1]
+
+ if network_alphas is not None and len(network_alphas) >= 1:
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
+ network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+ # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
+ lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+
+ if "use_dora" in lora_config_kwargs:
+ if lora_config_kwargs["use_dora"]:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError(
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<", "0.9.0"):
+ lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
+ lora_config = LoraConfig(**lora_config_kwargs)
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(self)
+
+ # =", "0.13.1"):
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+ # To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
+ # we should also delete the `peft_config` associated to the `adapter_name`.
+ try:
+ inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+ incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+ # Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
+ if not self._hf_peft_config_loaded:
+ self._hf_peft_config_loaded = True
+ except Exception as e:
+ # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
+ if hasattr(self, "peft_config"):
+ for module in self.modules():
+ if isinstance(module, BaseTunerLayer):
+ active_adapters = module.active_adapters
+ for active_adapter in active_adapters:
+ if adapter_name in active_adapter:
+ module.delete_adapter(adapter_name)
+
+ self.peft_config.pop(adapter_name)
+ logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
+ raise
+
+ warn_msg = ""
+ if incompatible_keys is not None:
+ # Check only for unexpected keys.
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+ if lora_unexpected_keys:
+ warn_msg = (
+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+ f" {', '.join(lora_unexpected_keys)}. "
+ )
+
+ # Filter missing keys specific to the current adapter.
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
+ if missing_keys:
+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+ if lora_missing_keys:
+ warn_msg += (
+ f"Loading adapter weights from state_dict led to missing keys in the model:"
+ f" {', '.join(lora_missing_keys)}."
+ )
+
+ if warn_msg:
+ logger.warning(warn_msg)
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ if prefix is not None and not state_dict:
+ logger.warning(
+ f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
+ "This is safe to ignore if LoRA state dict didn't originally have any "
+ f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
+ "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
+ "https://github.com/huggingface/diffusers/issues/new"
+ )
+
+ def save_lora_adapter(
+ self,
+ save_directory,
+ adapter_name: str = "default",
+ upcast_before_saving: bool = False,
+ safe_serialization: bool = True,
+ weight_name: Optional[str] = None,
+ ):
+ """
+ Save the LoRA parameters corresponding to the underlying model.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ adapter_name: (`str`, defaults to "default"): The name of the adapter to serialize. Useful when the
+ underlying model has multiple adapters loaded.
+ upcast_before_saving (`bool`, defaults to `False`):
+ Whether to cast the underlying model to `torch.float32` before serialization.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
+ """
+ from peft.utils import get_peft_model_state_dict
+
+ from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+ if adapter_name is None:
+ adapter_name = get_adapter_name(self)
+
+ if adapter_name not in getattr(self, "peft_config", {}):
+ raise ValueError(f"Adapter name {adapter_name} not found in the model.")
+
+ lora_layers_to_save = get_peft_model_state_dict(
+ self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
+ )
+ if os.path.isfile(save_directory):
+ raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ if safe_serialization:
+
+ def save_function(weights, filename):
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+ else:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if weight_name is None:
+ if safe_serialization:
+ weight_name = LORA_WEIGHT_NAME_SAFE
+ else:
+ weight_name = LORA_WEIGHT_NAME
+
+ # TODO: we could consider saving the `peft_config` as well.
+ save_path = Path(save_directory, weight_name).as_posix()
+ save_function(lora_layers_to_save, save_path)
+ logger.info(f"Model weights saved in {save_path}")
+
def set_adapters(
self,
adapter_names: Union[List[str], str],
diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index c0cbfc713857..c2843fc7406a 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -19,6 +19,7 @@
from huggingface_hub import snapshot_download
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
from packaging import version
+from typing_extensions import Self
from ..utils import deprecate, is_transformers_available, logging
from .single_file_utils import (
@@ -60,6 +61,7 @@ def load_single_file_sub_model(
local_files_only=False,
torch_dtype=None,
is_legacy_loading=False,
+ disable_mmap=False,
**kwargs,
):
if is_pipeline_module:
@@ -106,6 +108,7 @@ def load_single_file_sub_model(
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
+ disable_mmap=disable_mmap,
**kwargs,
)
@@ -267,7 +270,7 @@ class FromSingleFileMixin:
@classmethod
@validate_hf_hub_args
- def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
r"""
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
@@ -308,6 +311,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
hosted on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
component configs in Diffusers format.
+ disable_mmap ('bool', *optional*, defaults to 'False'):
+ Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+ is on a network mount or hard drive.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -329,7 +335,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
>>> # Enable float16 and move to GPU
>>> pipeline = StableDiffusionPipeline.from_single_file(
- ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
+ ... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
... torch_dtype=torch.float16,
... )
>>> pipeline.to("cuda")
@@ -355,9 +361,16 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
+ disable_mmap = kwargs.pop("disable_mmap", False)
is_legacy_loading = False
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ torch_dtype = torch.float32
+ logger.warning(
+ f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+ )
+
# We shouldn't allow configuring individual models components through a Pipeline creation method
# These model kwargs should be deprecated
scaling_factor = kwargs.get("scaling_factor", None)
@@ -383,6 +396,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
+ disable_mmap=disable_mmap,
)
if config is None:
@@ -504,6 +518,7 @@ def load_module(name, value):
original_config=original_config,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
+ disable_mmap=disable_mmap,
**kwargs,
)
except SingleFileComponentError as e:
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 3fe1abfbead5..dafdb3c26ddc 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -17,18 +17,31 @@
from contextlib import nullcontext
from typing import Optional
+import torch
from huggingface_hub.utils import validate_hf_hub_args
+from typing_extensions import Self
+from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
+ convert_auraflow_transformer_checkpoint_to_diffusers,
+ convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
+ convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
+ convert_ltx_transformer_checkpoint_to_diffusers,
+ convert_ltx_vae_checkpoint_to_diffusers,
+ convert_lumina2_to_diffusers,
+ convert_mochi_transformer_checkpoint_to_diffusers,
+ convert_sana_transformer_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
+ convert_wan_transformer_to_diffusers,
+ convert_wan_vae_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
@@ -42,7 +55,7 @@
if is_accelerate_available():
- from accelerate import init_empty_weights
+ from accelerate import dispatch_model, init_empty_weights
from ..models.modeling_utils import load_model_dict_into_meta
@@ -82,6 +95,43 @@
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
+ "LTXVideoTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "AutoencoderKLLTXVideo": {
+ "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
+ "default_subfolder": "vae",
+ },
+ "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
+ "MochiTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "HunyuanVideoTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "AuraFlowTransformer2DModel": {
+ "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "Lumina2Transformer2DModel": {
+ "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "SanaTransformer2DModel": {
+ "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "WanTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "AutoencoderKLWan": {
+ "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
+ "default_subfolder": "vae",
+ },
}
@@ -114,7 +164,7 @@ class FromOriginalModelMixin:
@classmethod
@validate_hf_hub_args
- def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
+ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
r"""
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
is set in evaluation mode (`model.eval()`) by default.
@@ -158,6 +208,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
+ disable_mmap ('bool', *optional*, defaults to 'False'):
+ Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+ is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
@@ -201,7 +254,17 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
local_files_only = kwargs.pop("local_files_only", None)
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
+ config_revision = kwargs.pop("config_revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
+ quantization_config = kwargs.pop("quantization_config", None)
+ device = kwargs.pop("device", None)
+ disable_mmap = kwargs.pop("disable_mmap", False)
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ torch_dtype = torch.float32
+ logger.warning(
+ f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+ )
if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
@@ -214,12 +277,20 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
+ disable_mmap=disable_mmap,
)
+ if quantization_config is not None:
+ hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
+ hf_quantizer.validate_environment()
+ torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+
+ else:
+ hf_quantizer = None
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
- if original_config:
+ if original_config is not None:
if "config_mapping_fn" in mapping_functions:
config_mapping_fn = mapping_functions["config_mapping_fn"]
else:
@@ -243,7 +314,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
)
else:
- if config:
+ if config is not None:
if isinstance(config, str):
default_pretrained_model_config_name = config
else:
@@ -269,6 +340,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
pretrained_model_name_or_path=default_pretrained_model_config_name,
subfolder=subfolder,
local_files_only=local_files_only,
+ token=token,
+ revision=config_revision,
)
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
@@ -295,9 +368,43 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
with ctx():
model = cls.from_config(diffusers_model_config)
- if is_accelerate_available():
- unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+ # Check if `_keep_in_fp32_modules` is not None
+ use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+ (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+ )
+ if use_keep_in_fp32_modules:
+ keep_in_fp32_modules = cls._keep_in_fp32_modules
+ if not isinstance(keep_in_fp32_modules, list):
+ keep_in_fp32_modules = [keep_in_fp32_modules]
+
+ else:
+ keep_in_fp32_modules = []
+
+ if hf_quantizer is not None:
+ hf_quantizer.preprocess_model(
+ model=model,
+ device_map=None,
+ state_dict=diffusers_format_checkpoint,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ )
+ device_map = None
+ if is_accelerate_available():
+ param_device = torch.device(device) if device else torch.device("cpu")
+ empty_state_dict = model.state_dict()
+ unexpected_keys = [
+ param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
+ ]
+ device_map = {"": param_device}
+ load_model_dict_into_meta(
+ model,
+ diffusers_format_checkpoint,
+ dtype=torch_dtype,
+ device_map=device_map,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ unexpected_keys=unexpected_keys,
+ )
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -310,9 +417,17 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
- if torch_dtype is not None:
+ if hf_quantizer is not None:
+ hf_quantizer.postprocess_model(model)
+ model.hf_quantizer = hf_quantizer
+
+ if torch_dtype is not None and hf_quantizer is None:
model.to(torch_dtype)
model.eval()
+ if device_map is not None:
+ device_map_kwargs = {"device_map": device_map}
+ dispatch_model(model, **device_map_kwargs)
+
return model
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 236fbd0c2295..42aee4a84822 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -62,7 +62,14 @@
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
"upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
- "controlnet": "control_model.time_embed.0.weight",
+ "controlnet": [
+ "control_model.time_embed.0.weight",
+ "controlnet_cond_embedding.conv_in.weight",
+ ],
+ # TODO: find non-Diffusers keys for controlnet_xl
+ "controlnet_xl": "add_embedding.linear_1.weight",
+ "controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight",
"playground-v2-5": "edm_mean",
"inpainting": "model.diffusion_model.input_blocks.0.0.weight",
"clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
@@ -74,16 +81,50 @@
"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
"stable_cascade_stage_c": "clip_txt_mapper.weight",
- "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+ "sd3": [
+ "joint_blocks.0.context_block.adaLN_modulation.1.bias",
+ "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+ ],
+ "sd35_large": [
+ "joint_blocks.37.x_block.mlp.fc1.weight",
+ "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
+ ],
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
"animatediff_rgb": "controlnet_cond_embedding.weight",
+ "auraflow": [
+ "double_layers.0.attn.w2q.weight",
+ "double_layers.0.attn.w1q.weight",
+ "cond_seq_linear.weight",
+ "t_embedder.mlp.0.weight",
+ ],
"flux": [
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
+ "ltx-video": [
+ "model.diffusion_model.patchify_proj.weight",
+ "model.diffusion_model.transformer_blocks.27.scale_shift_table",
+ "patchify_proj.weight",
+ "transformer_blocks.27.scale_shift_table",
+ "vae.per_channel_statistics.mean-of-means",
+ ],
+ "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
+ "autoencoder-dc-sana": "encoder.project_in.conv.bias",
+ "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
+ "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
+ "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
+ "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
+ "sana": [
+ "blocks.0.cross_attn.q_linear.weight",
+ "blocks.0.cross_attn.q_linear.bias",
+ "blocks.0.cross_attn.kv_linear.weight",
+ "blocks.0.cross_attn.kv_linear.bias",
+ ],
+ "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
+ "wan_vae": "decoder.middle.0.residual.0.gamma",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -95,6 +136,9 @@
"inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"},
"inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
"controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
+ "controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"},
+ "controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"},
+ "controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"},
"v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
"v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"},
"stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
@@ -113,14 +157,37 @@
"sd3": {
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
},
+ "sd35_large": {
+ "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
+ },
+ "sd35_medium": {
+ "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-medium",
+ },
"animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
+ "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
+ "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
+ "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+ "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
+ "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
+ "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
+ "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
+ "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
+ "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
+ "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
+ "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
+ "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
+ "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
+ "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
+ "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
+ "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
+ "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
}
# Use to configure model sample size when original config is provided
@@ -133,6 +200,7 @@
"inpainting": 512,
"inpainting_v2": 512,
"controlnet": 512,
+ "instruct-pix2pix": 512,
"v2": 768,
"v1": 512,
}
@@ -334,12 +402,14 @@ def load_single_file_checkpoint(
cache_dir=None,
local_files_only=None,
revision=None,
+ disable_mmap=False,
):
if os.path.isfile(pretrained_model_link_or_path):
pretrained_model_link_or_path = pretrained_model_link_or_path
else:
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
+ user_agent = {"file_type": "single_file", "framework": "pytorch"}
pretrained_model_link_or_path = _get_model_file(
repo_id,
weights_name=weights_name,
@@ -349,9 +419,10 @@ def load_single_file_checkpoint(
local_files_only=local_files_only,
token=token,
revision=revision,
+ user_agent=user_agent,
)
- checkpoint = load_state_dict(pretrained_model_link_or_path)
+ checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
# some checkpoints contain the model state dict under a "state_dict" key
while "state_dict" in checkpoint:
@@ -477,8 +548,16 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
model_type = "upscale"
- elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint:
- model_type = "controlnet"
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]):
+ if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint:
+ if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint:
+ model_type = "controlnet_xl_large"
+ elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint:
+ model_type = "controlnet_xl_mid"
+ else:
+ model_type = "controlnet_xl_small"
+ else:
+ model_type = "controlnet"
elif (
CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
@@ -504,8 +583,21 @@ def infer_diffusers_model_type(checkpoint):
):
model_type = "stable_cascade_stage_b"
- elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
- model_type = "sd3"
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
+ checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
+ ):
+ if "model.diffusion_model.pos_embed" in checkpoint:
+ key = "model.diffusion_model.pos_embed"
+ else:
+ key = "pos_embed"
+
+ if checkpoint[key].shape[1] == 36864:
+ model_type = "sd3"
+ elif checkpoint[key].shape[1] == 147456:
+ model_type = "sd35_medium"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
+ model_type = "sd35_large"
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
@@ -530,9 +622,78 @@ def infer_diffusers_model_type(checkpoint):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
- model_type = "flux-dev"
+ if "model.diffusion_model.img_in.weight" in checkpoint:
+ key = "model.diffusion_model.img_in.weight"
+ else:
+ key = "img_in.weight"
+
+ if checkpoint[key].shape[1] == 384:
+ model_type = "flux-fill"
+ elif checkpoint[key].shape[1] == 128:
+ model_type = "flux-depth"
+ else:
+ model_type = "flux-dev"
else:
model_type = "flux-schnell"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
+ if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
+ model_type = "ltx-video-0.9.1"
+ else:
+ model_type = "ltx-video"
+
+ elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
+ encoder_key = "encoder.project_in.conv.conv.bias"
+ decoder_key = "decoder.project_in.main.conv.weight"
+
+ if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
+ model_type = "autoencoder-dc-f32c32-sana"
+
+ elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
+ model_type = "autoencoder-dc-f32c32"
+
+ elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
+ model_type = "autoencoder-dc-f64c128"
+
+ else:
+ model_type = "autoencoder-dc-f128c512"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
+ model_type = "mochi-1-preview"
+
+ elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
+ model_type = "hunyuan-video"
+
+ elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
+ model_type = "auraflow"
+
+ elif (
+ CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
+ and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
+ ):
+ model_type = "instruct-pix2pix"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
+ model_type = "lumina2"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
+ model_type = "sana"
+
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
+ if "model.diffusion_model.patch_embedding.weight" in checkpoint:
+ target_key = "model.diffusion_model.patch_embedding.weight"
+ else:
+ target_key = "patch_embedding.weight"
+
+ if checkpoint[target_key].shape[0] == 1536:
+ model_type = "wan-t2v-1.3B"
+ elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
+ model_type = "wan-t2v-14B"
+ else:
+ model_type = "wan-i2v-14B"
+ elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
+ # All Wan models use the same VAE so we can use the same default model repo to fetch the config
+ model_type = "wan-t2v-14B"
else:
model_type = "v1"
@@ -1065,6 +1226,9 @@ def convert_controlnet_checkpoint(
config,
**kwargs,
):
+ # Return checkpoint if it's already been converted
+ if "time_embedding.linear_1.weight" in checkpoint:
+ return checkpoint
# Some controlnet ckpt files are distributed independently from the rest of the
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
if "time_embed.0.weight" in checkpoint:
@@ -1316,8 +1480,8 @@ def convert_open_clip_checkpoint(
if text_proj_key in checkpoint:
text_proj_dim = int(checkpoint[text_proj_key].shape[0])
- elif hasattr(text_model.config, "projection_dim"):
- text_proj_dim = text_model.config.projection_dim
+ elif hasattr(text_model.config, "hidden_size"):
+ text_proj_dim = text_model.config.hidden_size
else:
text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
@@ -1461,18 +1625,9 @@ def create_diffusers_clip_model_from_ldm(
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
if is_accelerate_available():
- unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+ load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
- _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
-
- if model._keys_to_ignore_on_load_unexpected is not None:
- for pat in model._keys_to_ignore_on_load_unexpected:
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
- if len(unexpected_keys) > 0:
- logger.warning(
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
- )
+ model.load_state_dict(diffusers_format_checkpoint, strict=False)
if torch_dtype is not None:
model.to(torch_dtype)
@@ -1670,6 +1825,28 @@ def swap_scale_shift(weight, dim):
return new_weight
+def swap_proj_gate(weight):
+ proj, gate = weight.chunk(2, dim=0)
+ new_weight = torch.cat([gate, proj], dim=0)
+ return new_weight
+
+
+def get_attn2_layers(state_dict):
+ attn2_layers = []
+ for key in state_dict.keys():
+ if "attn2." in key:
+ # Extract the layer number from the key
+ layer_num = int(key.split(".")[1])
+ attn2_layers.append(layer_num)
+
+ return tuple(sorted(set(attn2_layers)))
+
+
+def get_caption_projection_dim(state_dict):
+ caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+ return caption_projection_dim
+
+
def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
@@ -1678,7 +1855,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401
- caption_projection_dim = 1536
+ dual_attention_layers = get_attn2_layers(checkpoint)
+
+ caption_projection_dim = get_caption_projection_dim(checkpoint)
+ has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
# Positional and patch embeddings.
converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
@@ -1735,6 +1915,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+ # qk norm
+ if has_qk_norm:
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+ )
+
# output projections.
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -1750,6 +1945,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
f"joint_blocks.{i}.context_block.attn.proj.bias"
)
+ if i in dual_attention_layers:
+ # Q, K, V
+ sample_q2, sample_k2, sample_v2 = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+ )
+ sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+ # qk norm
+ if has_qk_norm:
+ converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+ )
+
+ # output projections.
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.proj.bias"
+ )
+
# norms.
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
@@ -1857,16 +2084,7 @@ def create_diffusers_t5_model_from_checkpoint(
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
if is_accelerate_available():
- unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
- if model._keys_to_ignore_on_load_unexpected is not None:
- for pat in model._keys_to_ignore_on_load_unexpected:
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
- if len(unexpected_keys) > 0:
- logger.warning(
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
- )
-
+ load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
model.load_state_dict(diffusers_format_checkpoint)
@@ -1907,6 +2125,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
+
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2098,3 +2317,945 @@ def swap_scale_shift(weight):
)
return converted_state_dict
+
+
+def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
+
+ TRANSFORMER_KEYS_RENAME_DICT = {
+ "model.diffusion_model.": "",
+ "patchify_proj": "proj_in",
+ "adaln_single": "time_embed",
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+
+ for key in list(converted_state_dict.keys()):
+ new_key = key
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
+
+
+def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key}
+
+ def remove_keys_(key: str, state_dict):
+ state_dict.pop(key)
+
+ VAE_KEYS_RENAME_DICT = {
+ # common
+ "vae.": "",
+ # decoder
+ "up_blocks.0": "mid_block",
+ "up_blocks.1": "up_blocks.0",
+ "up_blocks.2": "up_blocks.1.upsamplers.0",
+ "up_blocks.3": "up_blocks.1",
+ "up_blocks.4": "up_blocks.2.conv_in",
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
+ "up_blocks.6": "up_blocks.2",
+ "up_blocks.7": "up_blocks.3.conv_in",
+ "up_blocks.8": "up_blocks.3.upsamplers.0",
+ "up_blocks.9": "up_blocks.3",
+ # encoder
+ "down_blocks.0": "down_blocks.0",
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
+ "down_blocks.2": "down_blocks.0.conv_out",
+ "down_blocks.3": "down_blocks.1",
+ "down_blocks.4": "down_blocks.1.downsamplers.0",
+ "down_blocks.5": "down_blocks.1.conv_out",
+ "down_blocks.6": "down_blocks.2",
+ "down_blocks.7": "down_blocks.2.downsamplers.0",
+ "down_blocks.8": "down_blocks.3",
+ "down_blocks.9": "mid_block",
+ # common
+ "conv_shortcut": "conv_shortcut.conv",
+ "res_blocks": "resnets",
+ "norm3.norm": "norm3",
+ "per_channel_statistics.mean-of-means": "latents_mean",
+ "per_channel_statistics.std-of-means": "latents_std",
+ }
+
+ VAE_091_RENAME_DICT = {
+ # decoder
+ "up_blocks.0": "mid_block",
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
+ "up_blocks.2": "up_blocks.0",
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
+ "up_blocks.4": "up_blocks.1",
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
+ "up_blocks.6": "up_blocks.2",
+ "up_blocks.7": "up_blocks.3.upsamplers.0",
+ "up_blocks.8": "up_blocks.3",
+ # common
+ "last_time_embedder": "time_embedder",
+ "last_scale_shift_table": "scale_shift_table",
+ }
+
+ VAE_SPECIAL_KEYS_REMAP = {
+ "per_channel_statistics.channel": remove_keys_,
+ "per_channel_statistics.mean-of-means": remove_keys_,
+ "per_channel_statistics.mean-of-stds": remove_keys_,
+ "timestep_scale_multiplier": remove_keys_,
+ }
+
+ if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
+ VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
+
+ for key in list(converted_state_dict.keys()):
+ new_key = key
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
+
+
+def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+ def remap_qkv_(key: str, state_dict):
+ qkv = state_dict.pop(key)
+ q, k, v = torch.chunk(qkv, 3, dim=0)
+ parent_module, _, _ = key.rpartition(".qkv.conv.weight")
+ state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
+ state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
+ state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
+
+ def remap_proj_conv_(key: str, state_dict):
+ parent_module, _, _ = key.rpartition(".proj.conv.weight")
+ state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
+
+ AE_KEYS_RENAME_DICT = {
+ # common
+ "main.": "",
+ "op_list.": "",
+ "context_module": "attn",
+ "local_module": "conv_out",
+ # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
+ # If there were more scales, there would be more layers, so a loop would be better to handle this
+ "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
+ "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
+ "depth_conv.conv": "conv_depth",
+ "inverted_conv.conv": "conv_inverted",
+ "point_conv.conv": "conv_point",
+ "point_conv.norm": "norm",
+ "conv.conv.": "conv.",
+ "conv1.conv": "conv1",
+ "conv2.conv": "conv2",
+ "conv2.norm": "norm",
+ "proj.norm": "norm_out",
+ # encoder
+ "encoder.project_in.conv": "encoder.conv_in",
+ "encoder.project_out.0.conv": "encoder.conv_out",
+ "encoder.stages": "encoder.down_blocks",
+ # decoder
+ "decoder.project_in.conv": "decoder.conv_in",
+ "decoder.project_out.0": "decoder.norm_out",
+ "decoder.project_out.2.conv": "decoder.conv_out",
+ "decoder.stages": "decoder.up_blocks",
+ }
+
+ AE_F32C32_F64C128_F128C512_KEYS = {
+ "encoder.project_in.conv": "encoder.conv_in.conv",
+ "decoder.project_out.2.conv": "decoder.conv_out.conv",
+ }
+
+ AE_SPECIAL_KEYS_REMAP = {
+ "qkv.conv.weight": remap_qkv_,
+ "proj.conv.weight": remap_proj_conv_,
+ }
+ if "encoder.project_in.conv.bias" not in converted_state_dict:
+ AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)
+
+ for key in list(converted_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
+
+
+def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+
+ # Comfy checkpoints add this prefix
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ # Convert patch_embed
+ converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+ converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+ # Convert time_embed
+ converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+ converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+ converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+ converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+ converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+ converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+ converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+ converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+ converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+ converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+ converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+
+ # Convert transformer blocks
+ num_layers = 48
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ old_prefix = f"blocks.{i}."
+
+ # norm1
+ converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+ converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+ if i < num_layers - 1:
+ converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(
+ old_prefix + "mod_y.weight"
+ )
+ converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(
+ old_prefix + "mod_y.bias"
+ )
+ else:
+ converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+ old_prefix + "mod_y.weight"
+ )
+ converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(
+ old_prefix + "mod_y.bias"
+ )
+
+ # Visual attention
+ qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+
+ converted_state_dict[block_prefix + "attn1.to_q.weight"] = q
+ converted_state_dict[block_prefix + "attn1.to_k.weight"] = k
+ converted_state_dict[block_prefix + "attn1.to_v.weight"] = v
+ converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(
+ old_prefix + "attn.q_norm_x.weight"
+ )
+ converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(
+ old_prefix + "attn.k_norm_x.weight"
+ )
+ converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(
+ old_prefix + "attn.proj_x.weight"
+ )
+ converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+
+ # Context attention
+ qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
+ q, k, v = qkv_weight.chunk(3, dim=0)
+
+ converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+ converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+ converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+ converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
+ old_prefix + "attn.q_norm_y.weight"
+ )
+ converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
+ old_prefix + "attn.k_norm_y.weight"
+ )
+ if i < num_layers - 1:
+ converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
+ old_prefix + "attn.proj_y.weight"
+ )
+ converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(
+ old_prefix + "attn.proj_y.bias"
+ )
+
+ # MLP
+ converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+ checkpoint.pop(old_prefix + "mlp_x.w1.weight")
+ )
+ converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+ if i < num_layers - 1:
+ converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+ checkpoint.pop(old_prefix + "mlp_y.w1.weight")
+ )
+ converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(
+ old_prefix + "mlp_y.w2.weight"
+ )
+
+ # Output layers
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+ converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
+
+ return converted_state_dict
+
+
+def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
+ def remap_norm_scale_shift_(key, state_dict):
+ weight = state_dict.pop(key)
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+ def remap_txt_in_(key, state_dict):
+ def rename_key(key):
+ new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+ new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+ new_key = new_key.replace("txt_in", "context_embedder")
+ new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+ new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+ new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+ new_key = new_key.replace("mlp", "ff")
+ return new_key
+
+ if "self_attn_qkv" in key:
+ weight = state_dict.pop(key)
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+ else:
+ state_dict[rename_key(key)] = state_dict.pop(key)
+
+ def remap_img_attn_qkv_(key, state_dict):
+ weight = state_dict.pop(key)
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+ def remap_txt_attn_qkv_(key, state_dict):
+ weight = state_dict.pop(key)
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+ def remap_single_transformer_blocks_(key, state_dict):
+ hidden_size = 3072
+
+ if "linear1.weight" in key:
+ linear1_weight = state_dict.pop(key)
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+ q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
+ state_dict[f"{new_key}.attn.to_q.weight"] = q
+ state_dict[f"{new_key}.attn.to_k.weight"] = k
+ state_dict[f"{new_key}.attn.to_v.weight"] = v
+ state_dict[f"{new_key}.proj_mlp.weight"] = mlp
+
+ elif "linear1.bias" in key:
+ linear1_bias = state_dict.pop(key)
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
+ state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
+ state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
+ state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
+ state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
+
+ else:
+ new_key = key.replace("single_blocks", "single_transformer_blocks")
+ new_key = new_key.replace("linear2", "proj_out")
+ new_key = new_key.replace("q_norm", "attn.norm_q")
+ new_key = new_key.replace("k_norm", "attn.norm_k")
+ state_dict[new_key] = state_dict.pop(key)
+
+ TRANSFORMER_KEYS_RENAME_DICT = {
+ "img_in": "x_embedder",
+ "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+ "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+ "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+ "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+ "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+ "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+ "double_blocks": "transformer_blocks",
+ "img_attn_q_norm": "attn.norm_q",
+ "img_attn_k_norm": "attn.norm_k",
+ "img_attn_proj": "attn.to_out.0",
+ "txt_attn_q_norm": "attn.norm_added_q",
+ "txt_attn_k_norm": "attn.norm_added_k",
+ "txt_attn_proj": "attn.to_add_out",
+ "img_mod.linear": "norm1.linear",
+ "img_norm1": "norm1.norm",
+ "img_norm2": "norm2",
+ "img_mlp": "ff",
+ "txt_mod.linear": "norm1_context.linear",
+ "txt_norm1": "norm1.norm",
+ "txt_norm2": "norm2_context",
+ "txt_mlp": "ff_context",
+ "self_attn_proj": "attn.to_out.0",
+ "modulation.linear": "norm.linear",
+ "pre_norm": "norm.norm",
+ "final_layer.norm_final": "norm_out.norm",
+ "final_layer.linear": "proj_out",
+ "fc1": "net.0.proj",
+ "fc2": "net.2",
+ "input_embedder": "proj_in",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "txt_in": remap_txt_in_,
+ "img_attn_qkv": remap_img_attn_qkv_,
+ "txt_attn_qkv": remap_txt_attn_qkv_,
+ "single_blocks": remap_single_transformer_blocks_,
+ "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+ }
+
+ def update_state_dict_(state_dict, old_key, new_key):
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ for key in list(checkpoint.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_(checkpoint, key, new_key)
+
+ for key in list(checkpoint.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, checkpoint)
+
+ return checkpoint
+
+
+def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+ state_dict_keys = list(checkpoint.keys())
+
+ # Handle register tokens and positional embeddings
+ converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
+
+ # Handle time step projection
+ converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
+ converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
+ converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
+ converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
+
+ # Handle context embedder
+ converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
+
+ # Calculate the number of layers
+ def calculate_layers(keys, key_prefix):
+ layers = set()
+ for k in keys:
+ if key_prefix in k:
+ layer_num = int(k.split(".")[1]) # get the layer number
+ layers.add(layer_num)
+ return len(layers)
+
+ mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+ single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
+ # MMDiT blocks
+ for i in range(mmdit_layers):
+ # Feed-forward
+ path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
+ weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+ for orig_k, diffuser_k in path_mapping.items():
+ for k, v in weight_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
+ f"double_layers.{i}.{orig_k}.{k}.weight", None
+ )
+
+ # Norms
+ path_mapping = {"modX": "norm1", "modC": "norm1_context"}
+ for orig_k, diffuser_k in path_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
+ f"double_layers.{i}.{orig_k}.1.weight", None
+ )
+
+ # Attentions
+ x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
+ context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
+ for attn_mapping in [x_attn_mapping, context_attn_mapping]:
+ for k, v in attn_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+ f"double_layers.{i}.attn.{k}.weight", None
+ )
+
+ # Single-DiT blocks
+ for i in range(single_dit_layers):
+ # Feed-forward
+ mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+ for k, v in mapping.items():
+ converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
+ f"single_layers.{i}.mlp.{k}.weight", None
+ )
+
+ # Norms
+ converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+ f"single_layers.{i}.modCX.1.weight", None
+ )
+
+ # Attentions
+ x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
+ for k, v in x_attn_mapping.items():
+ converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+ f"single_layers.{i}.attn.{k}.weight", None
+ )
+ # Final blocks
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
+
+ # Handle the final norm layer
+ norm_weight = checkpoint.pop("modF.1.weight", None)
+ if norm_weight is not None:
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
+ else:
+ converted_state_dict["norm_out.linear.weight"] = None
+
+ converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
+ converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
+ converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
+
+ return converted_state_dict
+
+
+def convert_lumina2_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+
+ # Original Lumina-Image-2 has an extra norm paramter that is unused
+ # We just remove it here
+ checkpoint.pop("norm_final.weight", None)
+
+ # Comfy checkpoints add this prefix
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ LUMINA_KEY_MAP = {
+ "cap_embedder": "time_caption_embed.caption_embedder",
+ "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1",
+ "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2",
+ "attention": "attn",
+ ".out.": ".to_out.0.",
+ "k_norm": "norm_k",
+ "q_norm": "norm_q",
+ "w1": "linear_1",
+ "w2": "linear_2",
+ "w3": "linear_3",
+ "adaLN_modulation.1": "norm1.linear",
+ }
+ ATTENTION_NORM_MAP = {
+ "attention_norm1": "norm1.norm",
+ "attention_norm2": "norm2",
+ }
+ CONTEXT_REFINER_MAP = {
+ "context_refiner.0.attention_norm1": "context_refiner.0.norm1",
+ "context_refiner.0.attention_norm2": "context_refiner.0.norm2",
+ "context_refiner.1.attention_norm1": "context_refiner.1.norm1",
+ "context_refiner.1.attention_norm2": "context_refiner.1.norm2",
+ }
+ FINAL_LAYER_MAP = {
+ "final_layer.adaLN_modulation.1": "norm_out.linear_1",
+ "final_layer.linear": "norm_out.linear_2",
+ }
+
+ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
+ q_dim = 2304
+ k_dim = v_dim = 768
+
+ to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0)
+
+ return {
+ diffusers_key.replace("qkv", "to_q"): to_q,
+ diffusers_key.replace("qkv", "to_k"): to_k,
+ diffusers_key.replace("qkv", "to_v"): to_v,
+ }
+
+ for key in keys:
+ diffusers_key = key
+ for k, v in CONTEXT_REFINER_MAP.items():
+ diffusers_key = diffusers_key.replace(k, v)
+ for k, v in FINAL_LAYER_MAP.items():
+ diffusers_key = diffusers_key.replace(k, v)
+ for k, v in ATTENTION_NORM_MAP.items():
+ diffusers_key = diffusers_key.replace(k, v)
+ for k, v in LUMINA_KEY_MAP.items():
+ diffusers_key = diffusers_key.replace(k, v)
+
+ if "qkv" in diffusers_key:
+ converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key))
+ else:
+ converted_state_dict[diffusers_key] = checkpoint.pop(key)
+
+ return converted_state_dict
+
+
+def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401
+
+ # Positional and patch embeddings.
+ checkpoint.pop("pos_embed")
+ converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+ converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+ # Timestep embeddings.
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+ "t_embedder.mlp.0.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+ "t_embedder.mlp.2.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+ converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
+ converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
+
+ # Caption Projection.
+ checkpoint.pop("y_embedder.y_embedding")
+ converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
+ converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
+ converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
+ converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
+ converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
+
+ for i in range(num_layers):
+ converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
+ f"blocks.{i}.scale_shift_table"
+ )
+
+ # Self-Attention
+ sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
+ converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
+
+ # Output Projections
+ converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
+ f"blocks.{i}.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
+ f"blocks.{i}.attn.proj.bias"
+ )
+
+ # Cross-Attention
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
+ f"blocks.{i}.cross_attn.q_linear.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
+ f"blocks.{i}.cross_attn.q_linear.bias"
+ )
+
+ linear_sample_k, linear_sample_v = torch.chunk(
+ checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
+ )
+ linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
+ checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
+
+ # Output Projections
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+ f"blocks.{i}.cross_attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+ f"blocks.{i}.cross_attn.proj.bias"
+ )
+
+ # MLP
+ converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
+ f"blocks.{i}.mlp.inverted_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
+ f"blocks.{i}.mlp.inverted_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
+ f"blocks.{i}.mlp.depth_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
+ f"blocks.{i}.mlp.depth_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
+ f"blocks.{i}.mlp.point_conv.conv.weight"
+ )
+
+ # Final layer
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+ converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
+
+ return converted_state_dict
+
+
+def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ TRANSFORMER_KEYS_RENAME_DICT = {
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+ "time_projection.1": "condition_embedder.time_proj",
+ "cross_attn": "attn2",
+ "self_attn": "attn1",
+ ".o.": ".to_out.0.",
+ ".q.": ".to_q.",
+ ".k.": ".to_k.",
+ ".v.": ".to_v.",
+ ".k_img.": ".add_k_proj.",
+ ".v_img.": ".add_v_proj.",
+ ".norm_k_img.": ".norm_added_k.",
+ "head.modulation": "scale_shift_table",
+ "head.head": "proj_out",
+ "modulation": "scale_shift_table",
+ "ffn.0": "ffn.net.0.proj",
+ "ffn.2": "ffn.net.2",
+ # Hack to swap the layer names
+ # The original model calls the norms in following order: norm1, norm3, norm2
+ # We convert it to: norm1, norm2, norm3
+ "norm2": "norm__placeholder",
+ "norm3": "norm2",
+ "norm__placeholder": "norm3",
+ # For the I2V model
+ "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+ "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+ "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+ "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+ }
+
+ for key in list(checkpoint.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+
+ converted_state_dict[new_key] = checkpoint.pop(key)
+
+ return converted_state_dict
+
+
+def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+
+ # Create mappings for specific components
+ middle_key_mapping = {
+ # Encoder middle block
+ "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
+ "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
+ "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
+ "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
+ "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
+ "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
+ "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
+ "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
+ "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
+ "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
+ "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
+ "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
+ # Decoder middle block
+ "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
+ "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
+ "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
+ "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
+ "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
+ "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
+ "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
+ "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
+ "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
+ "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
+ "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
+ "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
+ }
+
+ # Create a mapping for attention blocks
+ attention_mapping = {
+ # Encoder middle attention
+ "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
+ "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
+ "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
+ "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
+ "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
+ # Decoder middle attention
+ "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
+ "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
+ "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
+ "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
+ "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
+ }
+
+ # Create a mapping for the head components
+ head_mapping = {
+ # Encoder head
+ "encoder.head.0.gamma": "encoder.norm_out.gamma",
+ "encoder.head.2.bias": "encoder.conv_out.bias",
+ "encoder.head.2.weight": "encoder.conv_out.weight",
+ # Decoder head
+ "decoder.head.0.gamma": "decoder.norm_out.gamma",
+ "decoder.head.2.bias": "decoder.conv_out.bias",
+ "decoder.head.2.weight": "decoder.conv_out.weight",
+ }
+
+ # Create a mapping for the quant components
+ quant_mapping = {
+ "conv1.weight": "quant_conv.weight",
+ "conv1.bias": "quant_conv.bias",
+ "conv2.weight": "post_quant_conv.weight",
+ "conv2.bias": "post_quant_conv.bias",
+ }
+
+ # Process each key in the state dict
+ for key, value in checkpoint.items():
+ # Handle middle block keys using the mapping
+ if key in middle_key_mapping:
+ new_key = middle_key_mapping[key]
+ converted_state_dict[new_key] = value
+ # Handle attention blocks using the mapping
+ elif key in attention_mapping:
+ new_key = attention_mapping[key]
+ converted_state_dict[new_key] = value
+ # Handle head keys using the mapping
+ elif key in head_mapping:
+ new_key = head_mapping[key]
+ converted_state_dict[new_key] = value
+ # Handle quant keys using the mapping
+ elif key in quant_mapping:
+ new_key = quant_mapping[key]
+ converted_state_dict[new_key] = value
+ # Handle encoder conv1
+ elif key == "encoder.conv1.weight":
+ converted_state_dict["encoder.conv_in.weight"] = value
+ elif key == "encoder.conv1.bias":
+ converted_state_dict["encoder.conv_in.bias"] = value
+ # Handle decoder conv1
+ elif key == "decoder.conv1.weight":
+ converted_state_dict["decoder.conv_in.weight"] = value
+ elif key == "decoder.conv1.bias":
+ converted_state_dict["decoder.conv_in.bias"] = value
+ # Handle encoder downsamples
+ elif key.startswith("encoder.downsamples."):
+ # Convert to down_blocks
+ new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
+
+ # Convert residual block naming but keep the original structure
+ if ".residual.0.gamma" in new_key:
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+ elif ".residual.2.bias" in new_key:
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+ elif ".residual.2.weight" in new_key:
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+ elif ".residual.3.gamma" in new_key:
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+ elif ".residual.6.bias" in new_key:
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+ elif ".residual.6.weight" in new_key:
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+ elif ".shortcut.bias" in new_key:
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+ elif ".shortcut.weight" in new_key:
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+
+ converted_state_dict[new_key] = value
+
+ # Handle decoder upsamples
+ elif key.startswith("decoder.upsamples."):
+ # Convert to up_blocks
+ parts = key.split(".")
+ block_idx = int(parts[2])
+
+ # Group residual blocks
+ if "residual" in key:
+ if block_idx in [0, 1, 2]:
+ new_block_idx = 0
+ resnet_idx = block_idx
+ elif block_idx in [4, 5, 6]:
+ new_block_idx = 1
+ resnet_idx = block_idx - 4
+ elif block_idx in [8, 9, 10]:
+ new_block_idx = 2
+ resnet_idx = block_idx - 8
+ elif block_idx in [12, 13, 14]:
+ new_block_idx = 3
+ resnet_idx = block_idx - 12
+ else:
+ # Keep as is for other blocks
+ converted_state_dict[key] = value
+ continue
+
+ # Convert residual block naming
+ if ".residual.0.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
+ elif ".residual.2.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
+ elif ".residual.2.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
+ elif ".residual.3.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
+ elif ".residual.6.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
+ elif ".residual.6.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
+ else:
+ new_key = key
+
+ converted_state_dict[new_key] = value
+
+ # Handle shortcut connections
+ elif ".shortcut." in key:
+ if block_idx == 4:
+ new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
+ new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
+
+ converted_state_dict[new_key] = value
+
+ # Handle upsamplers
+ elif ".resample." in key or ".time_conv." in key:
+ if block_idx == 3:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
+ elif block_idx == 7:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
+ elif block_idx == 11:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+
+ converted_state_dict[new_key] = value
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ converted_state_dict[new_key] = value
+ else:
+ # Keep other keys unchanged
+ converted_state_dict[key] = value
+
+ return converted_state_dict
diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py
index 30098c955d6b..9aeb81c3e911 100644
--- a/src/diffusers/loaders/textual_inversion.py
+++ b/src/diffusers/loaders/textual_inversion.py
@@ -40,7 +40,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
+ hf_token = kwargs.pop("hf_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
@@ -73,7 +73,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
- token=token,
+ token=hf_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -93,7 +93,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
- token=token,
+ token=hf_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -312,7 +312,7 @@ def load_textual_inversion(
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
+ hf_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -333,7 +333,7 @@ def load_textual_inversion(
from diffusers import StableDiffusionPipeline
import torch
- model_id = "runwayml/stable-diffusion-v1-5"
+ model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
@@ -352,7 +352,7 @@ def load_textual_inversion(
from diffusers import StableDiffusionPipeline
import torch
- model_id = "runwayml/stable-diffusion-v1-5"
+ model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
@@ -449,9 +449,9 @@ def load_textual_inversion(
# 7.5 Offload the model again
if is_model_cpu_offload:
- self.enable_model_cpu_offload()
+ self.enable_model_cpu_offload(device=device)
elif is_sequential_cpu_offload:
- self.enable_sequential_cpu_offload()
+ self.enable_sequential_cpu_offload(device=device)
# / Unsafe Code >
@@ -469,7 +469,7 @@ def unload_textual_inversion(
from diffusers import AutoPipelineForText2Image
import torch
- pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
+ pipeline = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
# Example 1
pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
@@ -497,19 +497,19 @@ def unload_textual_inversion(
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipeline.load_textual_inversion(
state_dict["clip_l"],
- token=["", ""],
+ tokens=["", ""],
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
)
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
pipeline.load_textual_inversion(
state_dict["clip_g"],
- token=["", ""],
+ tokens=["", ""],
text_encoder=pipeline.text_encoder_2,
tokenizer=pipeline.tokenizer_2,
)
- # Unload explicitly from both text encoders abd tokenizers
+ # Unload explicitly from both text encoders and tokenizers
pipeline.unload_textual_inversion(
tokens=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
)
diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py
new file mode 100644
index 000000000000..38a8a7ebe266
--- /dev/null
+++ b/src/diffusers/loaders/transformer_flux.py
@@ -0,0 +1,180 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import nullcontext
+
+from ..models.embeddings import (
+ ImageProjection,
+ MultiIPAdapterImageProjection,
+)
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..utils import (
+ is_accelerate_available,
+ is_torch_version,
+ logging,
+)
+
+
+if is_accelerate_available():
+ pass
+
+logger = logging.get_logger(__name__)
+
+
+class FluxTransformer2DLoadersMixin:
+ """
+ Load layers into a [`FluxTransformer2DModel`].
+ """
+
+ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
+ if low_cpu_mem_usage:
+ if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+ else:
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ updated_state_dict = {}
+ image_projection = None
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+
+ if "proj.weight" in state_dict:
+ # IP-Adapter
+ num_image_text_embeds = 4
+ if state_dict["proj.weight"].shape[0] == 65536:
+ num_image_text_embeds = 16
+ clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+ cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
+
+ with init_context():
+ image_projection = ImageProjection(
+ cross_attention_dim=cross_attention_dim,
+ image_embed_dim=clip_embeddings_dim,
+ num_image_text_embeds=num_image_text_embeds,
+ )
+
+ for key, value in state_dict.items():
+ diffusers_name = key.replace("proj", "image_embeds")
+ updated_state_dict[diffusers_name] = value
+
+ if not low_cpu_mem_usage:
+ image_projection.load_state_dict(updated_state_dict, strict=True)
+ else:
+ device_map = {"": self.device}
+ load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+
+ return image_projection
+
+ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
+ from ..models.attention_processor import (
+ FluxIPAdapterJointAttnProcessor2_0,
+ )
+
+ if low_cpu_mem_usage:
+ if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+ else:
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ # set ip-adapter cross-attention processors & load state_dict
+ attn_procs = {}
+ key_id = 0
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+ for name in self.attn_processors.keys():
+ if name.startswith("single_transformer_blocks"):
+ attn_processor_class = self.attn_processors[name].__class__
+ attn_procs[name] = attn_processor_class()
+ else:
+ cross_attention_dim = self.config.joint_attention_dim
+ hidden_size = self.inner_dim
+ attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+ num_image_text_embeds = []
+ for state_dict in state_dicts:
+ if "proj.weight" in state_dict["image_proj"]:
+ num_image_text_embed = 4
+ if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
+ num_image_text_embed = 16
+ # IP-Adapter
+ num_image_text_embeds += [num_image_text_embed]
+
+ with init_context():
+ attn_procs[name] = attn_processor_class(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ num_tokens=num_image_text_embeds,
+ dtype=self.dtype,
+ device=self.device,
+ )
+
+ value_dict = {}
+ for i, state_dict in enumerate(state_dicts):
+ value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+ value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+ value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
+ value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
+
+ if not low_cpu_mem_usage:
+ attn_procs[name].load_state_dict(value_dict)
+ else:
+ device_map = {"": self.device}
+ dtype = self.dtype
+ load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
+
+ key_id += 1
+
+ return attn_procs
+
+ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
+ if not isinstance(state_dicts, list):
+ state_dicts = [state_dicts]
+
+ self.encoder_hid_proj = None
+
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+ self.set_attn_processor(attn_procs)
+
+ image_projection_layers = []
+ for state_dict in state_dicts:
+ image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+ state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+ )
+ image_projection_layers.append(image_projection_layer)
+
+ self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+ self.config.encoder_hid_dim_type = "ip_image_proj"
diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py
new file mode 100644
index 000000000000..ece17e6728fa
--- /dev/null
+++ b/src/diffusers/loaders/transformer_sd3.py
@@ -0,0 +1,170 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import nullcontext
+from typing import Dict
+
+from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
+from ..models.embeddings import IPAdapterTimeImageProjection
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..utils import is_accelerate_available, is_torch_version, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SD3Transformer2DLoadersMixin:
+ """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
+
+ def _convert_ip_adapter_attn_to_diffusers(
+ self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
+ ) -> Dict:
+ if low_cpu_mem_usage:
+ if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+ else:
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ # IP-Adapter cross attention parameters
+ hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
+ ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
+ timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]
+
+ # Dict where key is transformer layer index, value is attention processor's state dict
+ # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
+ layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
+ for key, weights in state_dict.items():
+ idx, name = key.split(".", maxsplit=1)
+ layer_state_dict[int(idx)][name] = weights
+
+ # Create IP-Adapter attention processor & load state_dict
+ attn_procs = {}
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+ for idx, name in enumerate(self.attn_processors.keys()):
+ with init_context():
+ attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
+ hidden_size=hidden_size,
+ ip_hidden_states_dim=ip_hidden_states_dim,
+ head_dim=self.config.attention_head_dim,
+ timesteps_emb_dim=timesteps_emb_dim,
+ )
+
+ if not low_cpu_mem_usage:
+ attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
+ else:
+ device_map = {"": self.device}
+ load_model_dict_into_meta(
+ attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
+ )
+
+ return attn_procs
+
+ def _convert_ip_adapter_image_proj_to_diffusers(
+ self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
+ ) -> IPAdapterTimeImageProjection:
+ if low_cpu_mem_usage:
+ if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+ else:
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+
+ # Convert to diffusers
+ updated_state_dict = {}
+ for key, value in state_dict.items():
+ # InstantX/SD3.5-Large-IP-Adapter
+ if key.startswith("layers."):
+ idx = key.split(".")[1]
+ key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
+ key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
+ key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
+ key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
+ key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
+ key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
+ key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
+ key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
+ key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
+ updated_state_dict[key] = value
+
+ # Image projetion parameters
+ embed_dim = updated_state_dict["proj_in.weight"].shape[1]
+ output_dim = updated_state_dict["proj_out.weight"].shape[0]
+ hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
+ heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
+ num_queries = updated_state_dict["latents"].shape[1]
+ timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
+
+ # Image projection
+ with init_context():
+ image_proj = IPAdapterTimeImageProjection(
+ embed_dim=embed_dim,
+ output_dim=output_dim,
+ hidden_dim=hidden_dim,
+ heads=heads,
+ num_queries=num_queries,
+ timestep_in_dim=timestep_in_dim,
+ )
+
+ if not low_cpu_mem_usage:
+ image_proj.load_state_dict(updated_state_dict, strict=True)
+ else:
+ device_map = {"": self.device}
+ load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
+
+ return image_proj
+
+ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
+ """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+
+ Args:
+ state_dict (`Dict`):
+ State dict with keys "ip_adapter", which contains parameters for attention processors, and
+ "image_proj", which contains parameters for image projection net.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ """
+
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
+ self.set_attn_processor(attn_procs)
+
+ self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 2fa7732a6a3b..1d8aba900c85 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -21,7 +21,6 @@
import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
-from torch import nn
from ..models.embeddings import (
ImageProjection,
@@ -31,11 +30,12 @@
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_unet_state_dict_to_peft,
+ deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
@@ -43,13 +43,11 @@
is_torch_version,
logging,
)
+from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
-if is_accelerate_available():
- from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-
logger = logging.get_logger(__name__)
@@ -145,7 +143,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
adapter_name = kwargs.pop("adapter_name", None)
_pipeline = kwargs.pop("_pipeline", None)
network_alphas = kwargs.pop("network_alphas", None)
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
allow_pickle = False
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
@@ -209,6 +207,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
is_model_cpu_offload = False
is_sequential_cpu_offload = False
+ if is_lora:
+ deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
+ deprecate("load_attn_procs", "0.40.0", deprecation_message)
+
if is_custom_diffusion:
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
elif is_lora:
@@ -338,6 +340,17 @@ def _process_lora(
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
@@ -395,27 +408,7 @@ def _optionally_disable_offloading(cls, _pipeline):
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
- is_model_cpu_offload = False
- is_sequential_cpu_offload = False
-
- if _pipeline is not None and _pipeline.hf_device_map is None:
- for _, component in _pipeline.components.items():
- if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
- if not is_model_cpu_offload:
- is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
- if not is_sequential_cpu_offload:
- is_sequential_cpu_offload = (
- isinstance(component._hf_hook, AlignDevicesHook)
- or hasattr(component._hf_hook, "hooks")
- and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
- )
-
- logger.info(
- "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
- )
- remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
- return (is_model_cpu_offload, is_sequential_cpu_offload)
+ return _func_optionally_disable_offloading(_pipeline=_pipeline)
def save_attn_procs(
self,
@@ -487,6 +480,9 @@ def save_attn_procs(
)
state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
else:
+ deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
+ deprecate("save_attn_procs", "0.40.0", deprecation_message)
+
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
@@ -544,7 +540,7 @@ def _get_custom_diffusion_state_dict(self):
return state_dict
- def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
@@ -757,14 +753,16 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
- load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+ device_map = {"": self.device}
+ load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
return image_projection
- def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
from ..models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
+ IPAdapterXFormersAttnProcessor,
)
if low_cpu_mem_usage:
@@ -804,11 +802,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()
-
else:
- attn_processor_class = (
- IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
- )
+ if "XFormers" in str(self.attn_processors[name].__class__):
+ attn_processor_class = IPAdapterXFormersAttnProcessor
+ else:
+ attn_processor_class = (
+ IPAdapterAttnProcessor2_0
+ if hasattr(F, "scaled_dot_product_attention")
+ else IPAdapterAttnProcessor
+ )
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
@@ -845,13 +847,14 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
else:
device = next(iter(value_dict.values())).device
dtype = next(iter(value_dict.values())).dtype
- load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+ device_map = {"": device}
+ load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
key_id += 2
return attn_procs
- def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
+ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
old mode 100644
new mode 100755
index 4dda8c36ba1c..f7d70f1d9826
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -27,23 +27,38 @@
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
+ _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+ _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
+ _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
+ _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
+ _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
+ _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
+ _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
- _import_structure["controlnet"] = ["ControlNetModel"]
- _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
- _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
- _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
- _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
- _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+ _import_structure["cache_utils"] = ["CacheMixin"]
+ _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
+ _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
+ _import_structure["controlnets.controlnet_hunyuan"] = [
+ "HunyuanDiT2DControlNetModel",
+ "HunyuanDiT2DMultiControlNetModel",
+ ]
+ _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
+ _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
+ _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
+ _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+ _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
+ _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
_import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
+ _import_structure["transformers.consisid_transformer_3d"] = ["ConsisIDTransformer3DModel"]
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
@@ -51,13 +66,23 @@
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
+ _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
+ _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
+ _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
+ _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
+ _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
+ _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
+ _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
+ _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
+ _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -70,7 +95,7 @@
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
if is_flax_available():
- _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
+ _import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
@@ -80,39 +105,67 @@
from .adapter import MultiAdapter, T2IAdapter
from .autoencoders import (
AsymmetricAutoencoderKL,
+ AutoencoderDC,
AutoencoderKL,
+ AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
+ AutoencoderKLHunyuanVideo,
+ AutoencoderKLLTXVideo,
+ AutoencoderKLMagvit,
+ AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
+ AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderTiny,
ConsistencyDecoderVAE,
VQModel,
)
- from .controlnet import ControlNetModel
- from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
- from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
- from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
- from .controlnet_sparsectrl import SparseControlNetModel
- from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
+ from .cache_utils import CacheMixin
+ from .controlnets import (
+ ControlNetModel,
+ ControlNetUnionModel,
+ ControlNetXSAdapter,
+ FluxControlNetModel,
+ FluxMultiControlNetModel,
+ HunyuanDiT2DControlNetModel,
+ HunyuanDiT2DMultiControlNetModel,
+ MultiControlNetModel,
+ MultiControlNetUnionModel,
+ SD3ControlNetModel,
+ SD3MultiControlNetModel,
+ SparseControlNetModel,
+ UNetControlNetXSModel,
+ )
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
+ AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
+ CogView4Transformer2DModel,
+ ConsisIDTransformer3DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
+ EasyAnimateTransformer3DModel,
FluxTransformer2DModel,
HunyuanDiT2DModel,
+ HunyuanVideoTransformer3DModel,
LatteTransformer3DModel,
+ LTXVideoTransformer3DModel,
+ Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
+ MochiTransformer3DModel,
+ OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
+ SanaTransformer2DModel,
SD3Transformer2DModel,
StableAudioDiTModel,
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
+ WanTransformer3DModel,
)
from .unets import (
I2VGenXLUNet,
@@ -129,7 +182,7 @@
)
if is_flax_available():
- from .controlnet_flax import FlaxControlNetModel
+ from .controlnets import FlaxControlNetModel
from .unets import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index fb24a36bae75..42e65d898cec 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -18,18 +18,18 @@
from torch import nn
from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version
if is_torch_npu_available():
import torch_npu
-ACTIVATION_FUNCTIONS = {
- "swish": nn.SiLU(),
- "silu": nn.SiLU(),
- "mish": nn.Mish(),
- "gelu": nn.GELU(),
- "relu": nn.ReLU(),
+ACT2CLS = {
+ "swish": nn.SiLU,
+ "silu": nn.SiLU,
+ "mish": nn.Mish,
+ "gelu": nn.GELU,
+ "relu": nn.ReLU,
}
@@ -44,10 +44,10 @@ def get_activation(act_fn: str) -> nn.Module:
"""
act_fn = act_fn.lower()
- if act_fn in ACTIVATION_FUNCTIONS:
- return ACTIVATION_FUNCTIONS[act_fn]
+ if act_fn in ACT2CLS:
+ return ACT2CLS[act_fn]()
else:
- raise ValueError(f"Unsupported activation function: {act_fn}")
+ raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")
class FP32SiLU(nn.Module):
@@ -79,10 +79,10 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
self.approximate = approximate
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
- if gate.device.type != "mps":
- return F.gelu(gate, approximate=self.approximate)
- # mps: gelu is not implemented for float16
- return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+ if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+ # fp16 gelu not supported on mps before torch 2.0
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+ return F.gelu(gate, approximate=self.approximate)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
- if gate.device.type != "mps":
- return F.gelu(gate)
- # mps: gelu is not implemented for float16
- return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+ if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+ # fp16 gelu not supported on mps before torch 2.0
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+ return F.gelu(gate)
def forward(self, hidden_states, *args, **kwargs):
if len(args) > 0 or kwargs.get("scale", None) is not None:
@@ -136,6 +136,7 @@ class SwiGLU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
+
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
self.activation = nn.SiLU()
@@ -163,3 +164,15 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
+
+
+class LinearActivation(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
+ super().__init__()
+
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.activation = get_activation(activation)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ return self.activation(hidden_states)
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 02ed1f965abf..93b11c2b43f0 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -19,7 +19,7 @@
from ..utils import deprecate, logging
from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
@@ -188,8 +188,13 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
self._chunk_dim = dim
def forward(
- self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
):
+ joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb
@@ -206,7 +211,9 @@ def forward(
# Attention.
attn_output, context_attn_output = self.attn(
- hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ **joint_attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
@@ -214,7 +221,7 @@ def forward(
hidden_states = hidden_states + attn_output
if self.use_dual_attention:
- attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
hidden_states = hidden_states + attn_output2
@@ -605,7 +612,6 @@ def __init__(
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
- inner_dim = int(2 * inner_dim / 3)
# custom hidden_size factor multiplier
if ffn_dim_multiplier is not None:
inner_dim = int(ffn_dim_multiplier * inner_dim)
@@ -1222,6 +1228,8 @@ def __init__(
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
elif activation_fn == "swiglu":
act_fn = SwiGLU(dim, inner_dim, bias=bias)
+ elif activation_fn == "linear-silu":
+ act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
self.net = nn.ModuleList([])
# project in
diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 25ae5d0a5d63..246f3afaf57c 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -216,8 +216,8 @@ def __call__(self, hidden_states, context=None, deterministic=True):
hidden_states = jax_memory_efficient_attention(
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
)
-
hidden_states = hidden_states.transpose(1, 0, 2)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
else:
# compute attentions
if self.split_head_dim:
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
old mode 100644
new mode 100755
index e735c4ee7d17..34276a544160
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -20,8 +20,8 @@
from torch import nn
from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
-from ..utils.import_utils import is_torch_npu_available, is_xformers_available
+from ..utils import deprecate, is_torch_xla_available, logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
@@ -36,6 +36,15 @@
else:
xformers = None
+if is_torch_xla_available():
+ # flash attention pallas kernel is introduced in the torch_xla 2.3 release.
+ if is_torch_xla_version(">", "2.2"):
+ from torch_xla.experimental.custom_kernel import flash_attention
+ from torch_xla.runtime import is_spmd
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
@maybe_allow_in_graph
class Attention(nn.Module):
@@ -120,14 +129,16 @@ def __init__(
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
+ out_context_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
+ is_causal: bool = False,
):
super().__init__()
# To prevent circular import.
- from .normalization import FP32LayerNorm, RMSNorm
+ from .normalization import FP32LayerNorm, LpNorm, RMSNorm
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
@@ -142,8 +153,10 @@ def __init__(
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
+ self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
+ self.is_causal = is_causal
# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -186,14 +199,23 @@ def __init__(
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
elif qk_norm == "layer_norm_across_heads":
- # Lumina applys qk norm across all heads
+ # Lumina applies qk norm across all heads
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
elif qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
+ elif qk_norm == "rms_norm_across_heads":
+ # LTX applies qk norm across all heads
+ self.norm_q = RMSNorm(dim_head * heads, eps=eps)
+ self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
+ elif qk_norm == "l2":
+ self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
+ self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
else:
- raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
+ raise ValueError(
+ f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
+ )
if cross_attention_norm is None:
self.norm_cross = None
@@ -234,22 +256,38 @@ def __init__(
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ else:
+ self.add_q_proj = None
+ self.add_k_proj = None
+ self.add_v_proj = None
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
+ else:
+ self.to_out = None
if self.context_pre_only is not None and not self.context_pre_only:
- self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+ else:
+ self.to_add_out = None
if qk_norm is not None and added_kv_proj_dim is not None:
- if qk_norm == "fp32_layer_norm":
+ if qk_norm == "layer_norm":
+ self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ elif qk_norm == "fp32_layer_norm":
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
elif qk_norm == "rms_norm":
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ elif qk_norm == "rms_norm_across_heads":
+ # Wan applies qk norm across all heads
+ # Wan also doesn't apply a q norm
+ self.norm_added_q = None
+ self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
else:
raise ValueError(
f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
@@ -268,6 +306,39 @@ def __init__(
)
self.set_processor(processor)
+ def set_use_xla_flash_attention(
+ self,
+ use_xla_flash_attention: bool,
+ partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+ is_flux=False,
+ ) -> None:
+ r"""
+ Set whether to use xla flash attention from `torch_xla` or not.
+
+ Args:
+ use_xla_flash_attention (`bool`):
+ Whether to use pallas flash attention kernel from `torch_xla` or not.
+ partition_spec (`Tuple[]`, *optional*):
+ Specify the partition specification if using SPMD. Otherwise None.
+ """
+ if use_xla_flash_attention:
+ if not is_torch_xla_available:
+ raise "torch_xla is not available"
+ elif is_torch_xla_version("<", "2.3"):
+ raise "flash attention pallas kernel is supported from torch_xla version 2.3"
+ elif is_spmd() and is_torch_xla_version("<", "2.4"):
+ raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
+ else:
+ if is_flux:
+ processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
+ else:
+ processor = XLAFlashAttnProcessor2_0(partition_spec)
+ else:
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+ self.set_processor(processor)
+
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
r"""
Set whether to use npu flash attention from `torch_npu` or not.
@@ -311,6 +382,17 @@ def set_use_memory_efficient_attention_xformers(
XFormersAttnAddedKVProcessor,
),
)
+ is_ip_adapter = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
+ )
+ is_joint_processor = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (
+ JointAttnProcessor2_0,
+ XFormersJointAttnProcessor,
+ ),
+ )
if use_memory_efficient_attention_xformers:
if is_added_kv_processor and is_custom_diffusion:
@@ -333,11 +415,12 @@ def set_use_memory_efficient_attention_xformers(
else:
try:
# Make sure we can run the memory efficient attention
- _ = xformers.ops.memory_efficient_attention(
- torch.randn((1, 2, 40), device="cuda"),
- torch.randn((1, 2, 40), device="cuda"),
- torch.randn((1, 2, 40), device="cuda"),
- )
+ dtype = None
+ if attention_op is not None:
+ op_fw, op_bw = attention_op
+ dtype, *_ = op_fw.SUPPORTED_DTYPES
+ q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+ _ = xformers.ops.memory_efficient_attention(q, q, q)
except Exception as e:
raise e
@@ -361,6 +444,21 @@ def set_use_memory_efficient_attention_xformers(
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
)
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+ elif is_ip_adapter:
+ processor = IPAdapterXFormersAttnProcessor(
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ num_tokens=self.processor.num_tokens,
+ scale=self.processor.scale,
+ attention_op=attention_op,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_ip"):
+ processor.to(
+ device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+ )
+ elif is_joint_processor:
+ processor = XFormersJointAttnProcessor(attention_op=attention_op)
else:
processor = XFormersAttnProcessor(attention_op=attention_op)
else:
@@ -379,6 +477,18 @@ def set_use_memory_efficient_attention_xformers(
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_custom_diffusion"):
processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ elif is_ip_adapter:
+ processor = IPAdapterAttnProcessor2_0(
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ num_tokens=self.processor.num_tokens,
+ scale=self.processor.scale,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_ip"):
+ processor.to(
+ device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+ )
else:
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
@@ -482,7 +592,7 @@ def forward(
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
- quiet_attn_parameters = {"ip_adapter_masks"}
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
]
@@ -631,10 +741,14 @@ def prepare_attention_mask(
if out_dim == 3:
if attention_mask.shape[0] < batch_size * head_size:
- attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ attention_mask = attention_mask.repeat_interleave(
+ head_size, dim=0, output_size=attention_mask.shape[0] * head_size
+ )
elif out_dim == 4:
attention_mask = attention_mask.unsqueeze(1)
- attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+ attention_mask = attention_mask.repeat_interleave(
+ head_size, dim=1, output_size=attention_mask.shape[1] * head_size
+ )
return attention_mask
@@ -697,7 +811,11 @@ def fuse_projections(self, fuse=True):
self.to_kv.bias.copy_(concatenated_bias)
# handle added projections for SD3 and others.
- if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+ if (
+ getattr(self, "add_q_proj", None) is not None
+ and getattr(self, "add_k_proj", None) is not None
+ and getattr(self, "add_v_proj", None) is not None
+ ):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
@@ -717,6 +835,269 @@ def fuse_projections(self, fuse=True):
self.fused_projections = fuse
+class SanaMultiscaleAttentionProjection(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_attention_heads: int,
+ kernel_size: int,
+ ) -> None:
+ super().__init__()
+
+ channels = 3 * in_channels
+ self.proj_in = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size,
+ padding=kernel_size // 2,
+ groups=channels,
+ bias=False,
+ )
+ self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+ return hidden_states
+
+
+class SanaMultiscaleLinearAttention(nn.Module):
+ r"""Lightweight multi-scale linear attention"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_attention_heads: Optional[int] = None,
+ attention_head_dim: int = 8,
+ mult: float = 1.0,
+ norm_type: str = "batch_norm",
+ kernel_sizes: Tuple[int, ...] = (5,),
+ eps: float = 1e-15,
+ residual_connection: bool = False,
+ ):
+ super().__init__()
+
+ # To prevent circular import
+ from .normalization import get_normalization
+
+ self.eps = eps
+ self.attention_head_dim = attention_head_dim
+ self.norm_type = norm_type
+ self.residual_connection = residual_connection
+
+ num_attention_heads = (
+ int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
+ )
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
+ self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
+ self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
+
+ self.to_qkv_multiscale = nn.ModuleList()
+ for kernel_size in kernel_sizes:
+ self.to_qkv_multiscale.append(
+ SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
+ )
+
+ self.nonlinearity = nn.ReLU()
+ self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
+ self.norm_out = get_normalization(norm_type, num_features=out_channels)
+
+ self.processor = SanaMultiscaleAttnProcessor2_0()
+
+ def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) # Adds padding
+ scores = torch.matmul(value, key.transpose(-1, -2))
+ hidden_states = torch.matmul(scores, query)
+
+ hidden_states = hidden_states.to(dtype=torch.float32)
+ hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
+ return hidden_states
+
+ def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+ scores = torch.matmul(key.transpose(-1, -2), query)
+ scores = scores.to(dtype=torch.float32)
+ scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
+ hidden_states = torch.matmul(value, scores.to(value.dtype))
+ return hidden_states
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.processor(self, hidden_states)
+
+
+class MochiAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim: int,
+ added_kv_proj_dim: int,
+ processor: "MochiAttnProcessor2_0",
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_proj_bias: bool = True,
+ out_dim: Optional[int] = None,
+ out_context_dim: Optional[int] = None,
+ out_bias: bool = True,
+ context_pre_only: bool = False,
+ eps: float = 1e-5,
+ ):
+ super().__init__()
+ from .normalization import MochiRMSNorm
+
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.out_context_dim = out_context_dim if out_context_dim else query_dim
+ self.context_pre_only = context_pre_only
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+
+ self.norm_q = MochiRMSNorm(dim_head, eps, True)
+ self.norm_k = MochiRMSNorm(dim_head, eps, True)
+ self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
+ self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ if self.context_pre_only is not None:
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+
+ if not self.context_pre_only:
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+
+ self.processor = processor
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+
+
+class MochiAttnProcessor2_0:
+ """Attention processor used in Mochi."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: "MochiAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+ if attn.norm_added_q is not None:
+ encoder_query = attn.norm_added_q(encoder_query)
+ if attn.norm_added_k is not None:
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ if image_rotary_emb is not None:
+
+ def apply_rotary_emb(x, freqs_cos, freqs_sin):
+ x_even = x[..., 0::2].float()
+ x_odd = x[..., 1::2].float()
+
+ cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+ sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+ return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+ query = apply_rotary_emb(query, *image_rotary_emb)
+ key = apply_rotary_emb(key, *image_rotary_emb)
+
+ query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+ encoder_query, encoder_key, encoder_value = (
+ encoder_query.transpose(1, 2),
+ encoder_key.transpose(1, 2),
+ encoder_value.transpose(1, 2),
+ )
+
+ sequence_length = query.size(2)
+ encoder_sequence_length = encoder_query.size(2)
+ total_length = sequence_length + encoder_sequence_length
+
+ batch_size, heads, _, dim = query.shape
+ attn_outputs = []
+ for idx in range(batch_size):
+ mask = attention_mask[idx][None, :]
+ valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
+
+ valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
+ valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
+ valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
+
+ valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
+ valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
+ valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
+
+ attn_output = F.scaled_dot_product_attention(
+ valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
+ )
+ valid_sequence_length = attn_output.size(2)
+ attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
+ attn_outputs.append(attn_output)
+
+ hidden_states = torch.cat(attn_outputs, dim=0)
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+ hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+ (sequence_length, encoder_sequence_length), dim=1
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if hasattr(attn, "to_add_out"):
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+
+
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
@@ -1041,7 +1422,7 @@ class JointAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
@@ -1136,6 +1517,7 @@ def __call__(
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
residual = hidden_states
@@ -1521,92 +1903,84 @@ def __call__(
return hidden_states, encoder_hidden_states
-class AuraFlowAttnProcessor2_0:
- """Attention processor used typically in processing Aura Flow."""
+class XFormersJointAttnProcessor:
+ r"""
+ Processor for implementing memory efficient attention using xFormers.
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
- raise ImportError(
- "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
- )
+ Args:
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+ operator.
+ """
+
+ def __init__(self, attention_op: Optional[Callable] = None):
+ self.attention_op = attention_op
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
- batch_size = hidden_states.shape[0]
+ residual = hidden_states
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
- # `context` projections.
- if encoder_hidden_states is not None:
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- # Reshape.
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
- query = query.view(batch_size, -1, attn.heads, head_dim)
- key = key.view(batch_size, -1, attn.heads, head_dim)
- value = value.view(batch_size, -1, attn.heads, head_dim)
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
- # Apply QK norm.
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
- # Concatenate the projections.
+ # `context` projections.
if encoder_hidden_states is not None:
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- )
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- )
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
-
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
- query = query.transpose(1, 2)
- key = key.transpose(1, 2)
- value = value.transpose(1, 2)
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
- # Attention.
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
- # Split the attention outputs.
if encoder_hidden_states is not None:
+ # Split the attention outputs.
hidden_states, encoder_hidden_states = (
- hidden_states[:, encoder_hidden_states.shape[1] :],
- hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
)
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
- if encoder_hidden_states is not None:
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
@@ -1614,19 +1988,206 @@ def __call__(
return hidden_states
-class FusedAuraFlowAttnProcessor2_0:
- """Attention processor used typically in processing Aura Flow with fused projections."""
+class AllegroAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
+ """
def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+ if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
- "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+ "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
- hidden_states: torch.FloatTensor,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None and not attn.is_cross_attention:
+ from .embeddings import apply_rotary_emb_allegro
+
+ query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
+ key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class AuraFlowAttnProcessor2_0:
+ """Attention processor used typically in processing Aura Flow."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+ raise ImportError(
+ "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ # Reshape.
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ # Apply QK norm.
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Concatenate the projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
+
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # Attention.
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FusedAuraFlowAttnProcessor2_0:
+ """Attention processor used typically in processing Aura Flow with fused projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+ raise ImportError(
+ "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
*args,
**kwargs,
@@ -1778,7 +2339,10 @@ def __call__(
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -1792,6 +2356,7 @@ def __call__(
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
@@ -1799,13 +2364,13 @@ def __call__(
return hidden_states
-class FusedFluxAttnProcessor2_0:
+class FluxAttnProcessor2_0_NPU:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
- "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
)
def __call__(
@@ -1819,9 +2384,9 @@ def __call__(
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -1836,15 +2401,11 @@ def __call__(
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- # `context` projections.
if encoder_hidden_states is not None:
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
@@ -1872,7 +2433,23 @@ def __call__(
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ if query.dtype in (torch.float16, torch.bfloat16):
+ hidden_states = torch_npu.npu_fusion_attention(
+ query,
+ key,
+ value,
+ attn.heads,
+ input_layout="BNSD",
+ pse=None,
+ scale=1.0 / math.sqrt(query.shape[-1]),
+ pre_tockens=65536,
+ next_tockens=65536,
+ keep_prob=1.0,
+ sync=False,
+ inner_precise=0,
+ )[0]
+ else:
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -1893,33 +2470,381 @@ def __call__(
return hidden_states
-class CogVideoXAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
- query and key vectors, but does not include spatial normalization.
- """
+class FusedFluxAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ raise ImportError(
+ "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
def __call__(
self,
attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- text_seq_length = encoder_hidden_states.size(1)
+ ) -> torch.FloatTensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ # `sample` projections.
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
- if attention_mask is not None:
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ (
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ (
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ if query.dtype in (torch.float16, torch.bfloat16):
+ hidden_states = torch_npu.npu_fusion_attention(
+ query,
+ key,
+ value,
+ attn.heads,
+ input_layout="BNSD",
+ pse=None,
+ scale=1.0 / math.sqrt(query.shape[-1]),
+ pre_tockens=65536,
+ next_tockens=65536,
+ keep_prob=1.0,
+ sync=False,
+ inner_precise=0,
+ )[0]
+ else:
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
+ """Flux Attention processor for IP-Adapter."""
+
+ def __init__(
+ self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+ ):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+ self.scale = scale
+
+ self.to_k_ip = nn.ModuleList(
+ [
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+ for _ in range(len(num_tokens))
+ ]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+ for _ in range(len(num_tokens))
+ ]
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ip_hidden_states: Optional[List[torch.Tensor]] = None,
+ ip_adapter_masks: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ hidden_states_query_proj = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ # IP-adapter
+ ip_query = hidden_states_query_proj
+ ip_attn_output = torch.zeros_like(hidden_states)
+
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+ ):
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ current_ip_hidden_states = F.scaled_dot_product_attention(
+ ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+ current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+ ip_attn_output += scale * current_ip_hidden_states
+
+ return hidden_states, encoder_hidden_states, ip_attn_output
+ else:
+ return hidden_states
+
+
+class CogVideoXAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
@@ -2050,76 +2975,278 @@ class XFormersAttnAddedKVProcessor:
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- residual = hidden_states
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
- batch_size, sequence_length, _ = hidden_states.shape
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class XFormersAttnProcessor:
+ r"""
+ Processor for implementing memory efficient attention using xFormers.
+
+ Args:
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+ operator.
+ """
+
+ def __init__(self, attention_op: Optional[Callable] = None):
+ self.attention_op = attention_op
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, key_tokens, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+ if attention_mask is not None:
+ # expand our mask's singleton query_tokens dimension:
+ # [batch*heads, 1, key_tokens] ->
+ # [batch*heads, query_tokens, key_tokens]
+ # so that it can be added as a bias onto the attention scores that xformers computes:
+ # [batch*heads, query_tokens, key_tokens]
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+ _, query_tokens, _ = hidden_states.shape
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class AttnProcessorNPU:
+ r"""
+ Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
+ fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
+ not significant.
+
+ """
+
+ def __init__(self):
+ if not is_torch_npu_available():
+ raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+ attention_mask = attention_mask.repeat(1, 1, hidden_states.shape[1], 1)
+ if attention_mask.dtype == torch.bool:
+ attention_mask = torch.logical_not(attention_mask.bool())
+ else:
+ attention_mask = attention_mask.bool()
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
- query = attn.to_q(hidden_states)
- query = attn.head_to_batch_dim(query)
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- if not attn.only_cross_attention:
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ if query.dtype in (torch.float16, torch.bfloat16):
+ hidden_states = torch_npu.npu_fusion_attention(
+ query,
+ key,
+ value,
+ attn.heads,
+ input_layout="BNSD",
+ pse=None,
+ atten_mask=attention_mask,
+ scale=1.0 / math.sqrt(query.shape[-1]),
+ pre_tockens=65536,
+ next_tockens=65536,
+ keep_prob=1.0,
+ sync=False,
+ inner_precise=0,
+ )[0]
else:
- key = encoder_hidden_states_key_proj
- value = encoder_hidden_states_value_proj
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
- hidden_states = xformers.ops.memory_efficient_attention(
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
- )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
- hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
- hidden_states = hidden_states + residual
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
-class XFormersAttnProcessor:
+class AttnProcessor2_0:
r"""
- Processor for implementing memory efficient attention using xFormers.
-
- Args:
- attention_op (`Callable`, *optional*, defaults to `None`):
- The base
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
- operator.
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
- def __init__(self, attention_op: Optional[Callable] = None):
- self.attention_op = attention_op
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
@@ -2136,7 +3263,6 @@ def __call__(
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
-
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -2146,20 +3272,15 @@ def __call__(
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
- batch_size, key_tokens, _ = (
+ batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
- attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
if attention_mask is not None:
- # expand our mask's singleton query_tokens dimension:
- # [batch*heads, 1, key_tokens] ->
- # [batch*heads, query_tokens, key_tokens]
- # so that it can be added as a bias onto the attention scores that xformers computes:
- # [batch*heads, query_tokens, key_tokens]
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
- _, query_tokens, _ = hidden_states.shape
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -2174,15 +3295,27 @@ def __call__(
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
- hidden_states = xformers.ops.memory_efficient_attention(
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
- hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@@ -2200,17 +3333,21 @@ def __call__(
return hidden_states
-class AttnProcessorNPU:
+class XLAFlashAttnProcessor2_0:
r"""
- Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
- fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
- not significant.
-
+ Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
"""
- def __init__(self):
- if not is_torch_npu_available():
- raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
+ def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+ if is_torch_xla_version("<", "2.3"):
+ raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+ if is_spmd() and is_torch_xla_version("<", "2.4"):
+ raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+ self.partition_spec = partition_spec
def __call__(
self,
@@ -2222,10 +3359,6 @@ def __call__(
*args,
**kwargs,
) -> torch.Tensor:
- if len(args) > 0 or kwargs.get("scale", None) is not None:
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
- deprecate("scale", "1.0.0", deprecation_message)
-
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -2267,25 +3400,32 @@ def __call__(
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
- if query.dtype in (torch.float16, torch.bfloat16):
- hidden_states = torch_npu.npu_fusion_attention(
- query,
- key,
- value,
- attn.heads,
- input_layout="BNSD",
- pse=None,
- atten_mask=attention_mask,
- scale=1.0 / math.sqrt(query.shape[-1]),
- pre_tockens=65536,
- next_tockens=65536,
- keep_prob=1.0,
- sync=False,
- inner_precise=0,
- )[0]
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
+ if attention_mask is not None:
+ attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
+ # Convert mask to float and replace 0s with -inf and 1s with 0
+ attention_mask = (
+ attention_mask.float()
+ .masked_fill(attention_mask == 0, float("-inf"))
+ .masked_fill(attention_mask == 1, float(0.0))
+ )
+
+ # Apply attention mask to key
+ key = key + attention_mask
+ query /= math.sqrt(query.shape[3])
+ partition_spec = self.partition_spec if is_spmd() else None
+ hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
else:
- # TODO: add support for attn.scale when we move to Torch 2.1
+ logger.warning(
+ "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
+ )
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -2301,17 +3441,117 @@ def __call__(
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
- if attn.residual_connection:
- hidden_states = hidden_states + residual
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class XLAFluxFlashAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+ """
+
+ def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+ if is_torch_xla_version("<", "2.3"):
+ raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+ if is_spmd() and is_torch_xla_version("<", "2.4"):
+ raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+ self.partition_spec = partition_spec
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ query /= math.sqrt(head_dim)
+ hidden_states = flash_attention(query, key, value, causal=False)
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
- hidden_states = hidden_states / attn.rescale_output_factor
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
- return hidden_states
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
-class AttnProcessor2_0:
+class MochiVaeAttnProcessor2_0:
r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ Attention processor used in Mochi VAE.
"""
def __init__(self):
@@ -2324,23 +3564,9 @@ def __call__(
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
- *args,
- **kwargs,
) -> torch.Tensor:
- if len(args) > 0 or kwargs.get("scale", None) is not None:
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
- deprecate("scale", "1.0.0", deprecation_message)
-
residual = hidden_states
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ is_single_frame = hidden_states.shape[1] == 1
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
@@ -2352,15 +3578,24 @@ def __call__(
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+ if is_single_frame:
+ hidden_states = attn.to_v(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+ return hidden_states
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -2369,7 +3604,6 @@ def __call__(
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
@@ -2381,7 +3615,7 @@ def __call__(
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -2392,9 +3626,6 @@ def __call__(
# dropout
hidden_states = attn.to_out[1](hidden_states)
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
if attn.residual_connection:
hidden_states = hidden_states + residual
@@ -2479,8 +3710,10 @@ def __call__(
if kv_heads != attn.heads:
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
heads_per_kv_head = attn.heads // kv_heads
- key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
- value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
+ key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
+ value = torch.repeat_interleave(
+ value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
+ )
if attn.norm_q is not None:
query = attn.norm_q(query)
@@ -3604,27 +4837,227 @@ class SpatialNorm(nn.Module):
The number of channels for the quantized vector as described in the paper.
"""
- def __init__(
- self,
- f_channels: int,
- zq_channels: int,
- ):
- super().__init__()
- self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
- self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
- self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+ def __init__(
+ self,
+ f_channels: int,
+ zq_channels: int,
+ ):
+ super().__init__()
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+ f_size = f.shape[-2:]
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+
+class IPAdapterAttnProcessor(nn.Module):
+ r"""
+ Attention processor for Multiple IP-Adapters.
+
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+ The context length of the image features.
+ scale (`float` or List[`float`], defaults to 1.0):
+ the weight scale of image prompt.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
+ self.num_tokens = num_tokens
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+ self.scale = scale
+
+ self.to_k_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ scale: float = 1.0,
+ ip_adapter_masks: Optional[torch.Tensor] = None,
+ ):
+ residual = hidden_states
+
+ # separate ip_hidden_states from encoder_hidden_states
+ if encoder_hidden_states is not None:
+ if isinstance(encoder_hidden_states, tuple):
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+ else:
+ deprecation_message = (
+ "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+ )
+ deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ [encoder_hidden_states[:, end_pos:, :]],
+ )
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if ip_adapter_masks is not None:
+ if not isinstance(ip_adapter_masks, List):
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+ raise ValueError(
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+ f"({len(ip_hidden_states)})"
+ )
+ else:
+ for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+ if mask is None:
+ continue
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+ raise ValueError(
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
+ "[1, num_images_for_ip_adapter, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+ )
+ if mask.shape[1] != ip_state.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
+ )
+ if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of scales ({len(scale)}) at index {index}"
+ )
+ else:
+ ip_adapter_masks = [None] * len(self.scale)
+
+ # for ip-adapter
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+ ):
+ skip = False
+ if isinstance(scale, list):
+ if all(s == 0 for s in scale):
+ skip = True
+ elif scale == 0:
+ skip = True
+ if not skip:
+ if mask is not None:
+ if not isinstance(scale, list):
+ scale = [scale] * mask.shape[1]
+
+ current_num_images = mask.shape[1]
+ for i in range(current_num_images):
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask[:, i, :, :],
+ batch_size,
+ _current_ip_hidden_states.shape[1],
+ _current_ip_hidden_states.shape[2],
+ )
+
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+ else:
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+ hidden_states = hidden_states + scale * current_ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
- def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
- f_size = f.shape[-2:]
- zq = F.interpolate(zq, size=f_size, mode="nearest")
- norm_f = self.norm_layer(f)
- new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
- return new_f
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
-class IPAdapterAttnProcessor(nn.Module):
+class IPAdapterAttnProcessor2_0(torch.nn.Module):
r"""
- Attention processor for Multiple IP-Adapters.
+ Attention processor for IP-Adapter for PyTorch 2.0.
Args:
hidden_size (`int`):
@@ -3633,13 +5066,18 @@ class IPAdapterAttnProcessor(nn.Module):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
- scale (`float` or List[`float`], defaults to 1.0):
+ scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
super().__init__()
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
@@ -3700,7 +5138,12 @@ def __call__(
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -3715,13 +5158,22 @@ def __call__(
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
- query = attn.head_to_batch_dim(query)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, List):
@@ -3735,6 +5187,8 @@ def __call__(
)
else:
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+ if mask is None:
+ continue
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape "
@@ -3774,12 +5228,19 @@ def __call__(
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
- ip_key = attn.head_to_batch_dim(ip_key)
- ip_value = attn.head_to_batch_dim(ip_value)
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
- _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
- _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ _current_ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
@@ -3789,18 +5250,24 @@ def __call__(
)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
- ip_key = attn.head_to_batch_dim(ip_key)
- ip_value = attn.head_to_batch_dim(ip_value)
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
- current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
- current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ current_ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + scale * current_ip_hidden_states
@@ -3820,9 +5287,9 @@ def __call__(
return hidden_states
-class IPAdapterAttnProcessor2_0(torch.nn.Module):
+class IPAdapterXFormersAttnProcessor(torch.nn.Module):
r"""
- Attention processor for IP-Adapter for PyTorch 2.0.
+ Attention processor for IP-Adapter using xFormers.
Args:
hidden_size (`int`):
@@ -3833,18 +5300,26 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
The context length of the image features.
scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+ operator.
"""
- def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+ def __init__(
+ self,
+ hidden_size,
+ cross_attention_dim=None,
+ num_tokens=(4,),
+ scale=1.0,
+ attention_op: Optional[Callable] = None,
+ ):
super().__init__()
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
+ self.attention_op = attention_op
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
@@ -3857,21 +5332,21 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale
self.scale = scale
self.to_k_ip = nn.ModuleList(
- [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
self.to_v_ip = nn.ModuleList(
- [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
def __call__(
self,
attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
- ip_adapter_masks: Optional[torch.Tensor] = None,
+ ip_adapter_masks: Optional[torch.FloatTensor] = None,
):
residual = hidden_states
@@ -3906,9 +5381,14 @@ def __call__(
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+ # expand our mask's singleton query_tokens dimension:
+ # [batch*heads, 1, key_tokens] ->
+ # [batch*heads, query_tokens, key_tokens]
+ # so that it can be added as a bias onto the attention scores that xformers computes:
+ # [batch*heads, query_tokens, key_tokens]
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+ _, query_tokens, _ = hidden_states.shape
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -3923,131 +5403,291 @@ def __call__(
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if ip_hidden_states:
+ if ip_adapter_masks is not None:
+ if not isinstance(ip_adapter_masks, List):
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+ raise ValueError(
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+ f"({len(ip_hidden_states)})"
+ )
+ else:
+ for index, (mask, scale, ip_state) in enumerate(
+ zip(ip_adapter_masks, self.scale, ip_hidden_states)
+ ):
+ if mask is None:
+ continue
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+ raise ValueError(
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
+ "[1, num_images_for_ip_adapter, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+ )
+ if mask.shape[1] != ip_state.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
+ )
+ if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of scales ({len(scale)}) at index {index}"
+ )
+ else:
+ ip_adapter_masks = [None] * len(self.scale)
+
+ # for ip-adapter
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+ ):
+ skip = False
+ if isinstance(scale, list):
+ if all(s == 0 for s in scale):
+ skip = True
+ elif scale == 0:
+ skip = True
+ if not skip:
+ if mask is not None:
+ mask = mask.to(torch.float16)
+ if not isinstance(scale, list):
+ scale = [scale] * mask.shape[1]
+
+ current_num_images = mask.shape[1]
+ for i in range(current_num_images):
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+
+ _current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+ query, ip_key, ip_value, op=self.attention_op
+ )
+ _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+ _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask[:, i, :, :],
+ batch_size,
+ _current_ip_hidden_states.shape[1],
+ _current_ip_hidden_states.shape[2],
+ )
+
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+ else:
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+
+ current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+ query, ip_key, ip_value, op=self.attention_op
+ )
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+ hidden_states = hidden_states + scale * current_ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
+ """
+ Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
+ additional image-based information and timestep embeddings.
+
+ Args:
+ hidden_size (`int`):
+ The number of hidden channels.
+ ip_hidden_states_dim (`int`):
+ The image feature dimension.
+ head_dim (`int`):
+ The number of head channels.
+ timesteps_emb_dim (`int`, defaults to 1280):
+ The number of input channels for timestep embedding.
+ scale (`float`, defaults to 0.5):
+ IP-Adapter scale.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ ip_hidden_states_dim: int,
+ head_dim: int,
+ timesteps_emb_dim: int = 1280,
+ scale: float = 0.5,
+ ):
+ super().__init__()
+
+ # To prevent circular import
+ from .normalization import AdaLayerNorm, RMSNorm
+
+ self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
+ self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
+ self.norm_q = RMSNorm(head_dim, 1e-6)
+ self.norm_k = RMSNorm(head_dim, 1e-6)
+ self.norm_ip_k = RMSNorm(head_dim, 1e-6)
+ self.scale = scale
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ ip_hidden_states: torch.FloatTensor = None,
+ temb: torch.FloatTensor = None,
+ ) -> torch.FloatTensor:
+ """
+ Perform the attention computation, integrating image features (if provided) and timestep embeddings.
+
+ If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
+
+ Args:
+ attn (`Attention`):
+ Attention instance.
+ hidden_states (`torch.FloatTensor`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.FloatTensor`, *optional*):
+ The encoder hidden states.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Attention mask.
+ ip_hidden_states (`torch.FloatTensor`, *optional*):
+ Image embeddings.
+ temb (`torch.FloatTensor`, *optional*):
+ Timestep embeddings.
+
+ Returns:
+ `torch.FloatTensor`: Output hidden states.
+ """
+ residual = hidden_states
+
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ img_query = query
+ img_key = key
+ img_value = value
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
- if ip_adapter_masks is not None:
- if not isinstance(ip_adapter_masks, List):
- # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
- ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
- if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
- raise ValueError(
- f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
- f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
- f"({len(ip_hidden_states)})"
- )
- else:
- for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
- if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
- raise ValueError(
- "Each element of the ip_adapter_masks array should be a tensor with shape "
- "[1, num_images_for_ip_adapter, height, width]."
- " Please use `IPAdapterMaskProcessor` to preprocess your mask"
- )
- if mask.shape[1] != ip_state.shape[1]:
- raise ValueError(
- f"Number of masks ({mask.shape[1]}) does not match "
- f"number of ip images ({ip_state.shape[1]}) at index {index}"
- )
- if isinstance(scale, list) and not len(scale) == mask.shape[1]:
- raise ValueError(
- f"Number of masks ({mask.shape[1]}) does not match "
- f"number of scales ({len(scale)}) at index {index}"
- )
- else:
- ip_adapter_masks = [None] * len(self.scale)
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
- # for ip-adapter
- for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
- ):
- skip = False
- if isinstance(scale, list):
- if all(s == 0 for s in scale):
- skip = True
- elif scale == 0:
- skip = True
- if not skip:
- if mask is not None:
- if not isinstance(scale, list):
- scale = [scale] * mask.shape[1]
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
- current_num_images = mask.shape[1]
- for i in range(current_num_images):
- ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
- ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- _current_ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
+ if encoder_hidden_states is not None:
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
- _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
- batch_size, -1, attn.heads * head_dim
- )
- _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+ # IP Adapter
+ if self.scale != 0 and ip_hidden_states is not None:
+ # Norm image features
+ norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)
- mask_downsample = IPAdapterMaskProcessor.downsample(
- mask[:, i, :, :],
- batch_size,
- _current_ip_hidden_states.shape[1],
- _current_ip_hidden_states.shape[2],
- )
+ # To k and v
+ ip_key = self.to_k_ip(norm_ip_hidden_states)
+ ip_value = self.to_v_ip(norm_ip_hidden_states)
- mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
- hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
- else:
- ip_key = to_k_ip(current_ip_hidden_states)
- ip_value = to_v_ip(current_ip_hidden_states)
+ # Reshape
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ # Norm
+ query = self.norm_q(img_query)
+ img_key = self.norm_k(img_key)
+ ip_key = self.norm_ip_k(ip_key)
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- current_ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
+ # cat img
+ key = torch.cat([img_key, ip_key], dim=2)
+ value = torch.cat([img_value, ip_value], dim=2)
- current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
- batch_size, -1, attn.heads * head_dim
- )
- current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+ ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
- hidden_states = hidden_states + scale * current_ip_hidden_states
+ hidden_states = hidden_states + ip_hidden_states * self.scale
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
class PAGIdentitySelfAttnProcessor2_0:
@@ -4252,22 +5892,98 @@ def __call__(
return hidden_states
+class SanaMultiscaleAttnProcessor2_0:
+ r"""
+ Processor for implementing multiscale quadratic attention.
+ """
+
+ def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
+ height, width = hidden_states.shape[-2:]
+ if height * width > attn.attention_head_dim:
+ use_linear_attention = True
+ else:
+ use_linear_attention = False
+
+ residual = hidden_states
+
+ batch_size, _, height, width = list(hidden_states.size())
+ original_dtype = hidden_states.dtype
+
+ hidden_states = hidden_states.movedim(1, -1)
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ hidden_states = torch.cat([query, key, value], dim=3)
+ hidden_states = hidden_states.movedim(-1, 1)
+
+ multi_scale_qkv = [hidden_states]
+ for block in attn.to_qkv_multiscale:
+ multi_scale_qkv.append(block(hidden_states))
+
+ hidden_states = torch.cat(multi_scale_qkv, dim=1)
+
+ if use_linear_attention:
+ # for linear attention upcast hidden_states to float32
+ hidden_states = hidden_states.to(dtype=torch.float32)
+
+ hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
+
+ query, key, value = hidden_states.chunk(3, dim=2)
+ query = attn.nonlinearity(query)
+ key = attn.nonlinearity(key)
+
+ if use_linear_attention:
+ hidden_states = attn.apply_linear_attention(query, key, value)
+ hidden_states = hidden_states.to(dtype=original_dtype)
+ else:
+ hidden_states = attn.apply_quadratic_attention(query, key, value)
+
+ hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
+ hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+ if attn.norm_type == "rms_norm":
+ hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+ else:
+ hidden_states = attn.norm_out(hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
class LoRAAttnProcessor:
+ r"""
+ Processor for implementing attention with LoRA.
+ """
+
def __init__(self):
pass
class LoRAAttnProcessor2_0:
+ r"""
+ Processor for implementing attention with LoRA (enabled by default if you're using PyTorch 2.0).
+ """
+
def __init__(self):
pass
class LoRAXFormersAttnProcessor:
+ r"""
+ Processor for implementing attention with LoRA using xFormers.
+ """
+
def __init__(self):
pass
class LoRAAttnAddedKVProcessor:
+ r"""
+ Processor for implementing attention with LoRA with extra learnable key and value matrices for the text encoder.
+ """
+
def __init__(self):
pass
@@ -4283,6 +5999,170 @@ def __init__(self):
super().__init__()
+class SanaLinearAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product linear attention.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ original_dtype = hidden_states.dtype
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
+
+ query = F.relu(query)
+ key = F.relu(key)
+
+ query, key, value = query.float(), key.float(), value.float()
+
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
+ scores = torch.matmul(value, key)
+ hidden_states = torch.matmul(scores, query)
+
+ hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
+ hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
+ hidden_states = hidden_states.to(original_dtype)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if original_dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states
+
+
+class PAGCFGSanaLinearAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product linear attention.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ original_dtype = hidden_states.dtype
+
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+ query = attn.to_q(hidden_states_org)
+ key = attn.to_k(hidden_states_org)
+ value = attn.to_v(hidden_states_org)
+
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
+
+ query = F.relu(query)
+ key = F.relu(key)
+
+ query, key, value = query.float(), key.float(), value.float()
+
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
+ scores = torch.matmul(value, key)
+ hidden_states_org = torch.matmul(scores, query)
+
+ hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
+ hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
+ hidden_states_org = hidden_states_org.to(original_dtype)
+
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ # perturbed path (identity attention)
+ hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
+
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if original_dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states
+
+
+class PAGIdentitySanaLinearAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product linear attention.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ original_dtype = hidden_states.dtype
+
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+
+ query = attn.to_q(hidden_states_org)
+ key = attn.to_k(hidden_states_org)
+ value = attn.to_v(hidden_states_org)
+
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
+
+ query = F.relu(query)
+ key = F.relu(key)
+
+ query, key, value = query.float(), key.float(), value.float()
+
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
+ scores = torch.matmul(value, key)
+ hidden_states_org = torch.matmul(scores, query)
+
+ if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states_org = hidden_states_org.float()
+
+ hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
+ hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
+ hidden_states_org = hidden_states_org.to(original_dtype)
+
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ # perturbed path (identity attention)
+ hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
+
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if original_dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states
+
+
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
@@ -4297,23 +6177,59 @@ def __init__(self):
SlicedAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
+ FluxIPAdapterJointAttnProcessor2_0,
)
AttentionProcessor = Union[
AttnProcessor,
- AttnProcessor2_0,
- FusedAttnProcessor2_0,
- XFormersAttnProcessor,
- SlicedAttnProcessor,
+ CustomDiffusionAttnProcessor,
AttnAddedKVProcessor,
- SlicedAttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
+ JointAttnProcessor2_0,
+ PAGJointAttnProcessor2_0,
+ PAGCFGJointAttnProcessor2_0,
+ FusedJointAttnProcessor2_0,
+ AllegroAttnProcessor2_0,
+ AuraFlowAttnProcessor2_0,
+ FusedAuraFlowAttnProcessor2_0,
+ FluxAttnProcessor2_0,
+ FluxAttnProcessor2_0_NPU,
+ FusedFluxAttnProcessor2_0,
+ FusedFluxAttnProcessor2_0_NPU,
+ CogVideoXAttnProcessor2_0,
+ FusedCogVideoXAttnProcessor2_0,
XFormersAttnAddedKVProcessor,
- CustomDiffusionAttnProcessor,
+ XFormersAttnProcessor,
+ XLAFlashAttnProcessor2_0,
+ AttnProcessorNPU,
+ AttnProcessor2_0,
+ MochiVaeAttnProcessor2_0,
+ MochiAttnProcessor2_0,
+ StableAudioAttnProcessor2_0,
+ HunyuanAttnProcessor2_0,
+ FusedHunyuanAttnProcessor2_0,
+ PAGHunyuanAttnProcessor2_0,
+ PAGCFGHunyuanAttnProcessor2_0,
+ LuminaAttnProcessor2_0,
+ FusedAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
- PAGCFGIdentitySelfAttnProcessor2_0,
+ SlicedAttnProcessor,
+ SlicedAttnAddedKVProcessor,
+ SanaLinearAttnProcessor2_0,
+ PAGCFGSanaLinearAttnProcessor2_0,
+ PAGIdentitySanaLinearAttnProcessor2_0,
+ SanaMultiscaleLinearAttention,
+ SanaMultiscaleAttnProcessor2_0,
+ SanaMultiscaleAttentionProjection,
+ IPAdapterAttnProcessor,
+ IPAdapterAttnProcessor2_0,
+ IPAdapterXFormersAttnProcessor,
+ SD3IPAdapterJointAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
- PAGCFGHunyuanAttnProcessor2_0,
- PAGHunyuanAttnProcessor2_0,
+ PAGCFGIdentitySelfAttnProcessor2_0,
+ LoRAAttnProcessor,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnAddedKVProcessor,
]
diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py
index ccf4552b2a5e..f8f49ce4c797 100644
--- a/src/diffusers/models/autoencoders/__init__.py
+++ b/src/diffusers/models/autoencoders/__init__.py
@@ -1,7 +1,14 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
+from .autoencoder_dc import AutoencoderDC
from .autoencoder_kl import AutoencoderKL
+from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
+from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
+from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
+from .autoencoder_kl_magvit import AutoencoderKLMagvit
+from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
+from .autoencoder_kl_wan import AutoencoderKLWan
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
diff --git a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
index 3f4d46557bf7..c643dcc72a34 100644
--- a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""
+ _skip_layerwise_casting_patterns = ["decoder"]
+
@register_to_config
def __init__(
self,
diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py
new file mode 100644
index 000000000000..9146aa5c7c6c
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_dc.py
@@ -0,0 +1,722 @@
+# Copyright 2024 MIT, Tsinghua University, NVIDIA CORPORATION and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..attention_processor import SanaMultiscaleLinearAttention
+from ..modeling_utils import ModelMixin
+from ..normalization import RMSNorm, get_normalization
+from ..transformers.sana_transformer import GLUMBConv
+from .vae import DecoderOutput, EncoderOutput
+
+
+class ResBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ norm_type: str = "batch_norm",
+ act_fn: str = "relu6",
+ ) -> None:
+ super().__init__()
+
+ self.norm_type = norm_type
+
+ self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
+ self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
+ self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
+ self.norm = get_normalization(norm_type, out_channels)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.norm_type == "rms_norm":
+ # move channel to the last dimension so we apply RMSnorm across channel dimension
+ hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
+ else:
+ hidden_states = self.norm(hidden_states)
+
+ return hidden_states + residual
+
+
+class EfficientViTBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ mult: float = 1.0,
+ attention_head_dim: int = 32,
+ qkv_multiscales: Tuple[int, ...] = (5,),
+ norm_type: str = "batch_norm",
+ ) -> None:
+ super().__init__()
+
+ self.attn = SanaMultiscaleLinearAttention(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ mult=mult,
+ attention_head_dim=attention_head_dim,
+ norm_type=norm_type,
+ kernel_sizes=qkv_multiscales,
+ residual_connection=True,
+ )
+
+ self.conv_out = GLUMBConv(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ norm_type="rms_norm",
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.attn(x)
+ x = self.conv_out(x)
+ return x
+
+
+def get_block(
+ block_type: str,
+ in_channels: int,
+ out_channels: int,
+ attention_head_dim: int,
+ norm_type: str,
+ act_fn: str,
+ qkv_mutliscales: Tuple[int] = (),
+):
+ if block_type == "ResBlock":
+ block = ResBlock(in_channels, out_channels, norm_type, act_fn)
+
+ elif block_type == "EfficientViTBlock":
+ block = EfficientViTBlock(
+ in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales
+ )
+
+ else:
+ raise ValueError(f"Block with {block_type=} is not supported.")
+
+ return block
+
+
+class DCDownBlock2d(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
+ super().__init__()
+
+ self.downsample = downsample
+ self.factor = 2
+ self.stride = 1 if downsample else 2
+ self.group_size = in_channels * self.factor**2 // out_channels
+ self.shortcut = shortcut
+
+ out_ratio = self.factor**2
+ if downsample:
+ assert out_channels % out_ratio == 0
+ out_channels = out_channels // out_ratio
+
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=self.stride,
+ padding=1,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ x = self.conv(hidden_states)
+ if self.downsample:
+ x = F.pixel_unshuffle(x, self.factor)
+
+ if self.shortcut:
+ y = F.pixel_unshuffle(hidden_states, self.factor)
+ y = y.unflatten(1, (-1, self.group_size))
+ y = y.mean(dim=2)
+ hidden_states = x + y
+ else:
+ hidden_states = x
+
+ return hidden_states
+
+
+class DCUpBlock2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ interpolate: bool = False,
+ shortcut: bool = True,
+ interpolation_mode: str = "nearest",
+ ) -> None:
+ super().__init__()
+
+ self.interpolate = interpolate
+ self.interpolation_mode = interpolation_mode
+ self.shortcut = shortcut
+ self.factor = 2
+ self.repeats = out_channels * self.factor**2 // in_channels
+
+ out_ratio = self.factor**2
+
+ if not interpolate:
+ out_channels = out_channels * out_ratio
+
+ self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.interpolate:
+ x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
+ x = self.conv(x)
+ else:
+ x = self.conv(hidden_states)
+ x = F.pixel_shuffle(x, self.factor)
+
+ if self.shortcut:
+ y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
+ y = F.pixel_shuffle(y, self.factor)
+ hidden_states = x + y
+ else:
+ hidden_states = x
+
+ return hidden_states
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ latent_channels: int,
+ attention_head_dim: int = 32,
+ block_type: Union[str, Tuple[str]] = "ResBlock",
+ block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
+ layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
+ qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
+ downsample_block_type: str = "pixel_unshuffle",
+ out_shortcut: bool = True,
+ ):
+ super().__init__()
+
+ num_blocks = len(block_out_channels)
+
+ if isinstance(block_type, str):
+ block_type = (block_type,) * num_blocks
+
+ if layers_per_block[0] > 0:
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+ else:
+ self.conv_in = DCDownBlock2d(
+ in_channels=in_channels,
+ out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
+ downsample=downsample_block_type == "pixel_unshuffle",
+ shortcut=False,
+ )
+
+ down_blocks = []
+ for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
+ down_block_list = []
+
+ for _ in range(num_layers):
+ block = get_block(
+ block_type[i],
+ out_channel,
+ out_channel,
+ attention_head_dim=attention_head_dim,
+ norm_type="rms_norm",
+ act_fn="silu",
+ qkv_mutliscales=qkv_multiscales[i],
+ )
+ down_block_list.append(block)
+
+ if i < num_blocks - 1 and num_layers > 0:
+ downsample_block = DCDownBlock2d(
+ in_channels=out_channel,
+ out_channels=block_out_channels[i + 1],
+ downsample=downsample_block_type == "pixel_unshuffle",
+ shortcut=True,
+ )
+ down_block_list.append(downsample_block)
+
+ down_blocks.append(nn.Sequential(*down_block_list))
+
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
+
+ self.out_shortcut = out_shortcut
+ if out_shortcut:
+ self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states)
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states)
+
+ if self.out_shortcut:
+ x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
+ x = x.mean(dim=2)
+ hidden_states = self.conv_out(hidden_states) + x
+ else:
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ latent_channels: int,
+ attention_head_dim: int = 32,
+ block_type: Union[str, Tuple[str]] = "ResBlock",
+ block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
+ layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
+ qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
+ norm_type: Union[str, Tuple[str]] = "rms_norm",
+ act_fn: Union[str, Tuple[str]] = "silu",
+ upsample_block_type: str = "pixel_shuffle",
+ in_shortcut: bool = True,
+ ):
+ super().__init__()
+
+ num_blocks = len(block_out_channels)
+
+ if isinstance(block_type, str):
+ block_type = (block_type,) * num_blocks
+ if isinstance(norm_type, str):
+ norm_type = (norm_type,) * num_blocks
+ if isinstance(act_fn, str):
+ act_fn = (act_fn,) * num_blocks
+
+ self.conv_in = nn.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
+
+ self.in_shortcut = in_shortcut
+ if in_shortcut:
+ self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
+
+ up_blocks = []
+ for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
+ up_block_list = []
+
+ if i < num_blocks - 1 and num_layers > 0:
+ upsample_block = DCUpBlock2d(
+ block_out_channels[i + 1],
+ out_channel,
+ interpolate=upsample_block_type == "interpolate",
+ shortcut=True,
+ )
+ up_block_list.append(upsample_block)
+
+ for _ in range(num_layers):
+ block = get_block(
+ block_type[i],
+ out_channel,
+ out_channel,
+ attention_head_dim=attention_head_dim,
+ norm_type=norm_type[i],
+ act_fn=act_fn[i],
+ qkv_mutliscales=qkv_multiscales[i],
+ )
+ up_block_list.append(block)
+
+ up_blocks.insert(0, nn.Sequential(*up_block_list))
+
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
+
+ self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
+ self.conv_act = nn.ReLU()
+ self.conv_out = None
+
+ if layers_per_block[0] > 0:
+ self.conv_out = nn.Conv2d(channels, in_channels, 3, 1, 1)
+ else:
+ self.conv_out = DCUpBlock2d(
+ channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.in_shortcut:
+ x = hidden_states.repeat_interleave(
+ self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
+ )
+ hidden_states = self.conv_in(hidden_states) + x
+ else:
+ hidden_states = self.conv_in(hidden_states)
+
+ for up_block in reversed(self.up_blocks):
+ hidden_states = up_block(hidden_states)
+
+ hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ An Autoencoder model introduced in [DCAE](https://arxiv.org/abs/2410.10733) and used in
+ [SANA](https://arxiv.org/abs/2410.10629).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Args:
+ in_channels (`int`, defaults to `3`):
+ The number of input channels in samples.
+ latent_channels (`int`, defaults to `32`):
+ The number of channels in the latent space representation.
+ encoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
+ The type(s) of block to use in the encoder.
+ decoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
+ The type(s) of block to use in the decoder.
+ encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
+ The number of output channels for each block in the encoder.
+ decoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
+ The number of output channels for each block in the decoder.
+ encoder_layers_per_block (`Tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`):
+ The number of layers per block in the encoder.
+ decoder_layers_per_block (`Tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`):
+ The number of layers per block in the decoder.
+ encoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
+ Multi-scale configurations for the encoder's QKV (query-key-value) transformations.
+ decoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
+ Multi-scale configurations for the decoder's QKV (query-key-value) transformations.
+ upsample_block_type (`str`, defaults to `"pixel_shuffle"`):
+ The type of block to use for upsampling in the decoder.
+ downsample_block_type (`str`, defaults to `"pixel_unshuffle"`):
+ The type of block to use for downsampling in the encoder.
+ decoder_norm_types (`Union[str, Tuple[str]]`, defaults to `"rms_norm"`):
+ The normalization type(s) to use in the decoder.
+ decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
+ The activation function(s) to use in the decoder.
+ scaling_factor (`float`, defaults to `1.0`):
+ The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
+ space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
+ z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back
+ to the original scale with the formula: `z = 1 / scaling_factor * z`.
+ """
+
+ _supports_gradient_checkpointing = False
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ latent_channels: int = 32,
+ attention_head_dim: int = 32,
+ encoder_block_types: Union[str, Tuple[str]] = "ResBlock",
+ decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
+ encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
+ decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
+ encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
+ decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
+ encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
+ decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
+ upsample_block_type: str = "pixel_shuffle",
+ downsample_block_type: str = "pixel_unshuffle",
+ decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
+ decoder_act_fns: Union[str, Tuple[str]] = "silu",
+ scaling_factor: float = 1.0,
+ ) -> None:
+ super().__init__()
+
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ latent_channels=latent_channels,
+ attention_head_dim=attention_head_dim,
+ block_type=encoder_block_types,
+ block_out_channels=encoder_block_out_channels,
+ layers_per_block=encoder_layers_per_block,
+ qkv_multiscales=encoder_qkv_multiscales,
+ downsample_block_type=downsample_block_type,
+ )
+ self.decoder = Decoder(
+ in_channels=in_channels,
+ latent_channels=latent_channels,
+ attention_head_dim=attention_head_dim,
+ block_type=decoder_block_types,
+ block_out_channels=decoder_block_out_channels,
+ layers_per_block=decoder_layers_per_block,
+ qkv_multiscales=decoder_qkv_multiscales,
+ norm_type=decoder_norm_types,
+ act_fn=decoder_act_fns,
+ upsample_block_type=upsample_block_type,
+ )
+
+ self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
+ self.temporal_compression_ratio = 1
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 512
+ self.tile_sample_min_width = 512
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 448
+ self.tile_sample_stride_width = 448
+
+ self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled AE decoding. When this option is enabled, the AE will split the input tensor into tiles to compute
+ decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+ self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
+ decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, height, width = x.shape
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x, return_dict=False)[0]
+
+ encoded = self.encoder(x)
+
+ return encoded
+
+ @apply_forward_hook
+ def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, defaults to `True`):
+ Whether to return a [`~models.vae.EncoderOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.vae.EncoderOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ encoded = torch.cat(encoded_slices)
+ else:
+ encoded = self._encode(x)
+
+ if not return_dict:
+ return (encoded,)
+ return EncoderOutput(latent=encoded)
+
+ def _decode(self, z: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, height, width = z.shape
+
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+ return self.tiled_decode(z, return_dict=False)[0]
+
+ decoded = self.decoder(z)
+
+ return decoded
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.size(0) > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z)
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+ return b
+
+ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
+ batch_size, num_channels, height, width = x.shape
+ latent_height = height // self.spatial_compression_ratio
+ latent_width = width // self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, x.shape[2], self.tile_sample_stride_height):
+ row = []
+ for j in range(0, x.shape[3], self.tile_sample_stride_width):
+ tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+ if (
+ tile.shape[2] % self.spatial_compression_ratio != 0
+ or tile.shape[3] % self.spatial_compression_ratio != 0
+ ):
+ pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
+ pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
+ tile = F.pad(tile, (0, pad_w, 0, pad_h))
+ tile = self.encoder(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
+
+ if not return_dict:
+ return (encoded,)
+ return EncoderOutput(latent=encoded)
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, height, width = z.shape
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ decoded = torch.cat(result_rows, dim=2)
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
+ encoded = self.encode(sample, return_dict=False)[0]
+ decoded = self.decode(encoded, return_dict=False)[0]
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 99a7da4a0b6f..357df0c31087 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -17,6 +17,7 @@
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import deprecate
from ...utils.accelerate_utils import apply_forward_hook
@@ -34,7 +35,7 @@
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
-class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -137,10 +138,6 @@ def __init__(
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, (Encoder, Decoder)):
- module.gradient_checkpointing = value
-
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
new file mode 100644
index 000000000000..a76277366c09
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
@@ -0,0 +1,1131 @@
+# Copyright 2024 The RhymesAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils.accelerate_utils import apply_forward_hook
+from ..attention_processor import Attention, SpatialNorm
+from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
+from ..downsampling import Downsample2D
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from ..resnet import ResnetBlock2D
+from ..upsampling import Upsample2D
+
+
+class AllegroTemporalConvLayer(nn.Module):
+ r"""
+ Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from:
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: Optional[int] = None,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ up_sample: bool = False,
+ down_sample: bool = False,
+ stride: int = 1,
+ ) -> None:
+ super().__init__()
+
+ out_dim = out_dim or in_dim
+ pad_h = pad_w = int((stride - 1) * 0.5)
+ pad_t = 0
+
+ self.down_sample = down_sample
+ self.up_sample = up_sample
+
+ if down_sample:
+ self.conv1 = nn.Sequential(
+ nn.GroupNorm(norm_num_groups, in_dim),
+ nn.SiLU(),
+ nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)),
+ )
+ elif up_sample:
+ self.conv1 = nn.Sequential(
+ nn.GroupNorm(norm_num_groups, in_dim),
+ nn.SiLU(),
+ nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)),
+ )
+ else:
+ self.conv1 = nn.Sequential(
+ nn.GroupNorm(norm_num_groups, in_dim),
+ nn.SiLU(),
+ nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
+ )
+ self.conv2 = nn.Sequential(
+ nn.GroupNorm(norm_num_groups, out_dim),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
+ )
+ self.conv3 = nn.Sequential(
+ nn.GroupNorm(norm_num_groups, out_dim),
+ nn.SiLU(),
+ nn.Dropout(dropout),
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
+ )
+ self.conv4 = nn.Sequential(
+ nn.GroupNorm(norm_num_groups, out_dim),
+ nn.SiLU(),
+ nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
+ )
+
+ @staticmethod
+ def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2)
+ hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2)
+ return hidden_states
+
+ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor:
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+
+ if self.down_sample:
+ identity = hidden_states[:, :, ::2]
+ elif self.up_sample:
+ identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
+ else:
+ identity = hidden_states
+
+ if self.down_sample or self.up_sample:
+ hidden_states = self.conv1(hidden_states)
+ else:
+ hidden_states = self._pad_temporal_dim(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ if self.up_sample:
+ hidden_states = hidden_states.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3)
+
+ hidden_states = self._pad_temporal_dim(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ hidden_states = self._pad_temporal_dim(hidden_states)
+ hidden_states = self.conv3(hidden_states)
+
+ hidden_states = self._pad_temporal_dim(hidden_states)
+ hidden_states = self.conv4(hidden_states)
+
+ hidden_states = identity + hidden_states
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+
+ return hidden_states
+
+
+class AllegroDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ spatial_downsample: bool = True,
+ temporal_downsample: bool = False,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+
+ resnets = []
+ temp_convs = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ AllegroTemporalConvLayer(
+ out_channels,
+ out_channels,
+ dropout=0.1,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+
+ if temporal_downsample:
+ self.temp_convs_down = AllegroTemporalConvLayer(
+ out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3
+ )
+ self.add_temp_downsample = temporal_downsample
+
+ if spatial_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
+
+ if self.add_temp_downsample:
+ hidden_states = self.temp_convs_down(hidden_states, batch_size=batch_size)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ return hidden_states
+
+
+class AllegroUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ spatial_upsample: bool = True,
+ temporal_upsample: bool = False,
+ temb_channels: Optional[int] = None,
+ ):
+ super().__init__()
+
+ resnets = []
+ temp_convs = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ AllegroTemporalConvLayer(
+ out_channels,
+ out_channels,
+ dropout=0.1,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+
+ self.add_temp_upsample = temporal_upsample
+ if temporal_upsample:
+ self.temp_conv_up = AllegroTemporalConvLayer(
+ out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3
+ )
+
+ if spatial_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
+
+ if self.add_temp_upsample:
+ hidden_states = self.temp_conv_up(hidden_states, batch_size=batch_size)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ return hidden_states
+
+
+class AllegroMidBlock3DConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ add_attention: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ ):
+ super().__init__()
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ temp_convs = [
+ AllegroTemporalConvLayer(
+ in_channels,
+ in_channels,
+ dropout=0.1,
+ norm_num_groups=resnet_groups,
+ )
+ ]
+ attentions = []
+
+ if attention_head_dim is None:
+ attention_head_dim = in_channels
+
+ for _ in range(num_layers):
+ if add_attention:
+ attentions.append(
+ Attention(
+ in_channels,
+ heads=in_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+ else:
+ attentions.append(None)
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ temp_convs.append(
+ AllegroTemporalConvLayer(
+ in_channels,
+ in_channels,
+ dropout=0.1,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+ self.attentions = nn.ModuleList(attentions)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ hidden_states = self.resnets[0](hidden_states, temb=None)
+
+ hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size)
+
+ for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]):
+ hidden_states = attn(hidden_states)
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = temp_conv(hidden_states, batch_size=batch_size)
+
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ return hidden_states
+
+
+class AllegroEncoder3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str, ...] = (
+ "AllegroDownBlock3D",
+ "AllegroDownBlock3D",
+ "AllegroDownBlock3D",
+ "AllegroDownBlock3D",
+ ),
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ temporal_downsample_blocks: Tuple[bool, ...] = [True, True, False, False],
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ double_z: bool = True,
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ self.temp_conv_in = nn.Conv3d(
+ in_channels=block_out_channels[0],
+ out_channels=block_out_channels[0],
+ kernel_size=(3, 1, 1),
+ padding=(1, 0, 0),
+ )
+
+ self.down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ if down_block_type == "AllegroDownBlock3D":
+ down_block = AllegroDownBlock3D(
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ spatial_downsample=not is_final_block,
+ temporal_downsample=temporal_downsample_blocks[i],
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ )
+ else:
+ raise ValueError("Invalid `down_block_type` encountered. Must be `AllegroDownBlock3D`")
+
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = AllegroMidBlock3DConv(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=None,
+ )
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+
+ self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0))
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
+ batch_size = sample.shape[0]
+
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ sample = self.conv_in(sample)
+
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ residual = sample
+ sample = self.temp_conv_in(sample)
+ sample = sample + residual
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ # Down blocks
+ for down_block in self.down_blocks:
+ sample = self._gradient_checkpointing_func(down_block, sample)
+
+ # Mid block
+ sample = self._gradient_checkpointing_func(self.mid_block, sample)
+ else:
+ # Down blocks
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+
+ # Mid block
+ sample = self.mid_block(sample)
+
+ # Post process
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ residual = sample
+ sample = self.temp_conv_out(sample)
+ sample = sample + residual
+
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ sample = self.conv_out(sample)
+
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ return sample
+
+
+class AllegroDecoder3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 4,
+ out_channels: int = 3,
+ up_block_types: Tuple[str, ...] = (
+ "AllegroUpBlock3D",
+ "AllegroUpBlock3D",
+ "AllegroUpBlock3D",
+ "AllegroUpBlock3D",
+ ),
+ temporal_upsample_blocks: Tuple[bool, ...] = [False, True, True, False],
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ norm_type: str = "group", # group, spatial
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[-1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ self.temp_conv_in = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0))
+
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ temb_channels = in_channels if norm_type == "spatial" else None
+
+ # mid
+ self.mid_block = AllegroMidBlock3DConv(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=temb_channels,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ if up_block_type == "AllegroUpBlock3D":
+ up_block = AllegroUpBlock3D(
+ num_layers=layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ spatial_upsample=not is_final_block,
+ temporal_upsample=temporal_upsample_blocks[i],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ temb_channels=temb_channels,
+ resnet_time_scale_shift=norm_type,
+ )
+ else:
+ raise ValueError("Invalid `UP_block_type` encountered. Must be `AllegroUpBlock3D`")
+
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_type == "spatial":
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
+ else:
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+
+ self.conv_act = nn.SiLU()
+
+ self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0))
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
+ batch_size = sample.shape[0]
+
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ sample = self.conv_in(sample)
+
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ residual = sample
+ sample = self.temp_conv_in(sample)
+ sample = sample + residual
+
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ # Mid block
+ sample = self._gradient_checkpointing_func(self.mid_block, sample)
+
+ # Up blocks
+ for up_block in self.up_blocks:
+ sample = self._gradient_checkpointing_func(up_block, sample)
+
+ else:
+ # Mid block
+ sample = self.mid_block(sample)
+ sample = sample.to(upscale_dtype)
+
+ # Up blocks
+ for up_block in self.up_blocks:
+ sample = up_block(sample)
+
+ # Post process
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ residual = sample
+ sample = self.temp_conv_out(sample)
+ sample = sample + residual
+
+ sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ sample = self.conv_out(sample)
+
+ sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ return sample
+
+
+class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
+ r"""
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
+ [Allegro](https://github.com/rhymes-ai/Allegro).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, defaults to `3`):
+ Number of channels in the input image.
+ out_channels (int, defaults to `3`):
+ Number of channels in the output.
+ down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`):
+ Tuple of strings denoting which types of down blocks to use.
+ up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`):
+ Tuple of strings denoting which types of up blocks to use.
+ block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
+ Tuple of integers denoting number of output channels in each block.
+ temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`):
+ Tuple of booleans denoting which blocks to enable temporal downsampling in.
+ latent_channels (`int`, defaults to `4`):
+ Number of channels in latents.
+ layers_per_block (`int`, defaults to `2`):
+ Number of resnet or attention or temporal convolution layers per down/up block.
+ act_fn (`str`, defaults to `"silu"`):
+ The activation function to use.
+ norm_num_groups (`int`, defaults to `32`):
+ Number of groups to use in normalization layers.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ Ratio by which temporal dimension of samples are compressed.
+ sample_size (`int`, defaults to `320`):
+ Default latent size.
+ scaling_factor (`float`, defaults to `0.13235`):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ force_upcast (`bool`, default to `True`):
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str, ...] = (
+ "AllegroDownBlock3D",
+ "AllegroDownBlock3D",
+ "AllegroDownBlock3D",
+ "AllegroDownBlock3D",
+ ),
+ up_block_types: Tuple[str, ...] = (
+ "AllegroUpBlock3D",
+ "AllegroUpBlock3D",
+ "AllegroUpBlock3D",
+ "AllegroUpBlock3D",
+ ),
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False),
+ temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False),
+ latent_channels: int = 4,
+ layers_per_block: int = 2,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ temporal_compression_ratio: float = 4,
+ sample_size: int = 320,
+ scaling_factor: float = 0.13,
+ force_upcast: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.encoder = AllegroEncoder3D(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ temporal_downsample_blocks=temporal_downsample_blocks,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=True,
+ )
+ self.decoder = AllegroDecoder3D(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ temporal_upsample_blocks=temporal_upsample_blocks,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ )
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+
+ # TODO(aryan): For the 1.0.0 refactor, `temporal_compression_ratio` can be inferred directly and we don't need
+ # to use a specific parameter here or in other VAEs.
+
+ self.use_slicing = False
+ self.use_tiling = False
+
+ self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
+ self.tile_overlap_t = 8
+ self.tile_overlap_h = 120
+ self.tile_overlap_w = 80
+ sample_frames = 24
+
+ self.kernel = (sample_frames, sample_size, sample_size)
+ self.stride = (
+ sample_frames - self.tile_overlap_t,
+ sample_size - self.tile_overlap_h,
+ sample_size - self.tile_overlap_w,
+ )
+
+ def enable_tiling(self) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.use_tiling = True
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ # TODO(aryan)
+ # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ if self.use_tiling:
+ return self.tiled_encode(x)
+
+ raise NotImplementedError("Encoding without tiling has not been implemented yet.")
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of videos into latents.
+
+ Args:
+ x (`torch.Tensor`):
+ Input batch of videos.
+ return_dict (`bool`, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor) -> torch.Tensor:
+ # TODO(aryan): refactor tiling implementation
+ # if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+ if self.use_tiling:
+ return self.tiled_decode(z)
+
+ raise NotImplementedError("Decoding without tiling has not been implemented yet.")
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Decode a batch of videos.
+
+ Args:
+ z (`torch.Tensor`):
+ Input batch of latent vectors.
+ return_dict (`bool`, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z)
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ local_batch_size = 1
+ rs = self.spatial_compression_ratio
+ rt = self.config.temporal_compression_ratio
+
+ batch_size, num_channels, num_frames, height, width = x.shape
+
+ output_num_frames = math.floor((num_frames - self.kernel[0]) / self.stride[0]) + 1
+ output_height = math.floor((height - self.kernel[1]) / self.stride[1]) + 1
+ output_width = math.floor((width - self.kernel[2]) / self.stride[2]) + 1
+
+ count = 0
+ output_latent = x.new_zeros(
+ (
+ output_num_frames * output_height * output_width,
+ 2 * self.config.latent_channels,
+ self.kernel[0] // rt,
+ self.kernel[1] // rs,
+ self.kernel[2] // rs,
+ )
+ )
+ vae_batch_input = x.new_zeros((local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]))
+
+ for i in range(output_num_frames):
+ for j in range(output_height):
+ for k in range(output_width):
+ n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
+ h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
+ w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]
+
+ video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
+ vae_batch_input[count % local_batch_size] = video_cube
+
+ if (
+ count % local_batch_size == local_batch_size - 1
+ or count == output_num_frames * output_height * output_width - 1
+ ):
+ latent = self.encoder(vae_batch_input)
+
+ if (
+ count == output_num_frames * output_height * output_width - 1
+ and count % local_batch_size != local_batch_size - 1
+ ):
+ output_latent[count - count % local_batch_size :] = latent[: count % local_batch_size + 1]
+ else:
+ output_latent[count - local_batch_size + 1 : count + 1] = latent
+
+ vae_batch_input = x.new_zeros(
+ (local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2])
+ )
+
+ count += 1
+
+ latent = x.new_zeros(
+ (batch_size, 2 * self.config.latent_channels, num_frames // rt, height // rs, width // rs)
+ )
+ output_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
+ output_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs
+ output_overlap = (
+ output_kernel[0] - output_stride[0],
+ output_kernel[1] - output_stride[1],
+ output_kernel[2] - output_stride[2],
+ )
+
+ for i in range(output_num_frames):
+ n_start, n_end = i * output_stride[0], i * output_stride[0] + output_kernel[0]
+ for j in range(output_height):
+ h_start, h_end = j * output_stride[1], j * output_stride[1] + output_kernel[1]
+ for k in range(output_width):
+ w_start, w_end = k * output_stride[2], k * output_stride[2] + output_kernel[2]
+ latent_mean = _prepare_for_blend(
+ (i, output_num_frames, output_overlap[0]),
+ (j, output_height, output_overlap[1]),
+ (k, output_width, output_overlap[2]),
+ output_latent[i * output_height * output_width + j * output_width + k].unsqueeze(0),
+ )
+ latent[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean
+
+ latent = latent.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ latent = self.quant_conv(latent)
+ latent = latent.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ return latent
+
+ def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
+ local_batch_size = 1
+ rs = self.spatial_compression_ratio
+ rt = self.config.temporal_compression_ratio
+
+ latent_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs
+ latent_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs
+
+ batch_size, num_channels, num_frames, height, width = z.shape
+
+ ## post quant conv (a mapping)
+ z = z.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ z = self.post_quant_conv(z)
+ z = z.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+
+ output_num_frames = math.floor((num_frames - latent_kernel[0]) / latent_stride[0]) + 1
+ output_height = math.floor((height - latent_kernel[1]) / latent_stride[1]) + 1
+ output_width = math.floor((width - latent_kernel[2]) / latent_stride[2]) + 1
+
+ count = 0
+ decoded_videos = z.new_zeros(
+ (
+ output_num_frames * output_height * output_width,
+ self.config.out_channels,
+ self.kernel[0],
+ self.kernel[1],
+ self.kernel[2],
+ )
+ )
+ vae_batch_input = z.new_zeros(
+ (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2])
+ )
+
+ for i in range(output_num_frames):
+ for j in range(output_height):
+ for k in range(output_width):
+ n_start, n_end = i * latent_stride[0], i * latent_stride[0] + latent_kernel[0]
+ h_start, h_end = j * latent_stride[1], j * latent_stride[1] + latent_kernel[1]
+ w_start, w_end = k * latent_stride[2], k * latent_stride[2] + latent_kernel[2]
+
+ current_latent = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end]
+ vae_batch_input[count % local_batch_size] = current_latent
+
+ if (
+ count % local_batch_size == local_batch_size - 1
+ or count == output_num_frames * output_height * output_width - 1
+ ):
+ current_video = self.decoder(vae_batch_input)
+
+ if (
+ count == output_num_frames * output_height * output_width - 1
+ and count % local_batch_size != local_batch_size - 1
+ ):
+ decoded_videos[count - count % local_batch_size :] = current_video[
+ : count % local_batch_size + 1
+ ]
+ else:
+ decoded_videos[count - local_batch_size + 1 : count + 1] = current_video
+
+ vae_batch_input = z.new_zeros(
+ (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2])
+ )
+
+ count += 1
+
+ video = z.new_zeros((batch_size, self.config.out_channels, num_frames * rt, height * rs, width * rs))
+ video_overlap = (
+ self.kernel[0] - self.stride[0],
+ self.kernel[1] - self.stride[1],
+ self.kernel[2] - self.stride[2],
+ )
+
+ for i in range(output_num_frames):
+ n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0]
+ for j in range(output_height):
+ h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1]
+ for k in range(output_width):
+ w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2]
+ out_video_blend = _prepare_for_blend(
+ (i, output_num_frames, video_overlap[0]),
+ (j, output_height, video_overlap[1]),
+ (k, output_width, video_overlap[2]),
+ decoded_videos[i * output_height * output_width + j * output_width + k].unsqueeze(0),
+ )
+ video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend
+
+ video = video.permute(0, 2, 1, 3, 4).contiguous()
+ return video
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ generator (`torch.Generator`, *optional*):
+ PyTorch random number generator.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+
+def _prepare_for_blend(n_param, h_param, w_param, x):
+ # TODO(aryan): refactor
+ n, n_max, overlap_n = n_param
+ h, h_max, overlap_h = h_param
+ w, w_max, overlap_w = w_param
+ if overlap_n > 0:
+ if n > 0: # the head overlap part decays from 0 to 1
+ x[:, :, 0:overlap_n, :, :] = x[:, :, 0:overlap_n, :, :] * (
+ torch.arange(0, overlap_n).float().to(x.device) / overlap_n
+ ).reshape(overlap_n, 1, 1)
+ if n < n_max - 1: # the tail overlap part decays from 1 to 0
+ x[:, :, -overlap_n:, :, :] = x[:, :, -overlap_n:, :, :] * (
+ 1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n
+ ).reshape(overlap_n, 1, 1)
+ if h > 0:
+ x[:, :, :, 0:overlap_h, :] = x[:, :, :, 0:overlap_h, :] * (
+ torch.arange(0, overlap_h).float().to(x.device) / overlap_h
+ ).reshape(overlap_h, 1)
+ if h < h_max - 1:
+ x[:, :, :, -overlap_h:, :] = x[:, :, :, -overlap_h:, :] * (
+ 1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h
+ ).reshape(overlap_h, 1)
+ if w > 0:
+ x[:, :, :, :, 0:overlap_w] = x[:, :, :, :, 0:overlap_w] * (
+ torch.arange(0, overlap_w).float().to(x.device) / overlap_w
+ )
+ if w < w_max - 1:
+ x[:, :, :, :, -overlap_w:] = x[:, :, :, :, -overlap_w:] * (
+ 1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w
+ )
+ return x
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index 68b49d72acc5..e2b26396899f 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -94,20 +94,23 @@ def __init__(
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
- self.pad_mode = pad_mode
- time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
- height_pad = height_kernel_size // 2
- width_pad = width_kernel_size // 2
+ # TODO(aryan): configure calculation based on stride and dilation in the future.
+ # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
+ time_pad = time_kernel_size - 1
+ height_pad = (height_kernel_size - 1) // 2
+ width_pad = (width_kernel_size - 1) // 2
+ self.pad_mode = pad_mode
self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
+ self.const_padding_conv3d = (0, self.width_pad, self.height_pad)
self.temporal_dim = 2
self.time_kernel_size = time_kernel_size
- stride = (stride, 1, 1)
+ stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = CogVideoXSafeConv3d(
in_channels=in_channels,
@@ -115,23 +118,29 @@ def __init__(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
+ padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d,
+ padding_mode="zeros",
)
def fake_context_parallel_forward(
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
- kernel_size = self.time_kernel_size
- if kernel_size > 1:
- cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
- inputs = torch.cat(cached_inputs + [inputs], dim=2)
+ if self.pad_mode == "replicate":
+ inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
+ else:
+ kernel_size = self.time_kernel_size
+ if kernel_size > 1:
+ cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
+ inputs = torch.cat(cached_inputs + [inputs], dim=2)
return inputs
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
- conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
- padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
- inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
+ if self.pad_mode == "replicate":
+ conv_cache = None
+ else:
+ conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
output = self.conv(inputs)
return output, conv_cache
@@ -412,20 +421,13 @@ def forward(
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def create_forward(*inputs):
- return module(*inputs)
-
- return create_forward
-
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+ resnet,
hidden_states,
temb,
zq,
- conv_cache=conv_cache.get(conv_cache_key),
+ conv_cache.get(conv_cache_key),
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -514,16 +516,9 @@ def forward(
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def create_forward(*inputs):
- return module(*inputs)
-
- return create_forward
-
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+ resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -628,20 +623,13 @@ def forward(
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def create_forward(*inputs):
- return module(*inputs)
-
- return create_forward
-
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+ resnet,
hidden_states,
temb,
zq,
- conv_cache=conv_cache.get(conv_cache_key),
+ conv_cache.get(conv_cache_key),
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -765,39 +753,32 @@ def forward(
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
# 1. Down
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
- create_custom_forward(down_block),
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+ down_block,
hidden_states,
temb,
None,
- conv_cache=conv_cache.get(conv_cache_key),
+ conv_cache.get(conv_cache_key),
)
# 2. Mid
- hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block),
+ hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
+ self.mid_block,
hidden_states,
temb,
None,
- conv_cache=conv_cache.get("mid_block"),
+ conv_cache.get("mid_block"),
)
else:
# 1. Down
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = down_block(
- hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
+ hidden_states, temb, None, conv_cache.get(conv_cache_key)
)
# 2. Mid
@@ -931,32 +912,25 @@ def forward(
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
# 1. Mid
- hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block),
+ hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func(
+ self.mid_block,
hidden_states,
temb,
sample,
- conv_cache=conv_cache.get("mid_block"),
+ conv_cache.get("mid_block"),
)
# 2. Up
for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
- create_custom_forward(up_block),
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+ up_block,
hidden_states,
temb,
sample,
- conv_cache=conv_cache.get(conv_cache_key),
+ conv_cache.get(conv_cache_key),
)
else:
# 1. Mid
@@ -1049,6 +1023,7 @@ def __init__(
force_upcast: float = True,
use_quant_conv: bool = False,
use_post_quant_conv: bool = False,
+ invert_scale_latents: bool = False,
):
super().__init__()
@@ -1113,10 +1088,6 @@ def __init__(
self.tile_overlap_factor_height = 1 / 6
self.tile_overlap_factor_width = 1 / 5
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
- module.gradient_checkpointing = value
-
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
@@ -1467,7 +1438,7 @@ def forward(
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
- dec = self.decode(z)
+ dec = self.decode(z).sample
if not return_dict:
return (dec,)
- return dec
+ return DecoderOutput(sample=dec)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
new file mode 100644
index 000000000000..089e641d8852
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
@@ -0,0 +1,1096 @@
+# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..attention_processor import Attention
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def prepare_causal_attention_mask(
+ num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
+) -> torch.Tensor:
+ indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
+ indices_blocks = indices.repeat_interleave(height_width)
+ x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
+ mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
+
+ if batch_size is not None:
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
+ return mask
+
+
+class HunyuanVideoCausalConv3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ padding: Union[int, Tuple[int, int, int]] = 0,
+ dilation: Union[int, Tuple[int, int, int]] = 1,
+ bias: bool = True,
+ pad_mode: str = "replicate",
+ ) -> None:
+ super().__init__()
+
+ kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+ self.pad_mode = pad_mode
+ self.time_causal_padding = (
+ kernel_size[0] // 2,
+ kernel_size[0] // 2,
+ kernel_size[1] // 2,
+ kernel_size[1] // 2,
+ kernel_size[2] - 1,
+ 0,
+ )
+
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(hidden_states)
+
+
+class HunyuanVideoUpsampleCausal3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ kernel_size: int = 3,
+ stride: int = 1,
+ bias: bool = True,
+ upsample_factor: Tuple[float, float, float] = (2, 2, 2),
+ ) -> None:
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+ self.upsample_factor = upsample_factor
+
+ self.conv = HunyuanVideoCausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ num_frames = hidden_states.size(2)
+
+ first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2)
+ first_frame = F.interpolate(
+ first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest"
+ ).unsqueeze(2)
+
+ if num_frames > 1:
+ # See: https://github.com/pytorch/pytorch/issues/81665
+ # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate
+ # is fixed, this will raise either a runtime error, or fail silently with bad outputs.
+ # If you are encountering an error here, make sure to try running encoding/decoding with
+ # `vae.enable_tiling()` first. If that doesn't work, open an issue at:
+ # https://github.com/huggingface/diffusers/issues
+ other_frames = other_frames.contiguous()
+ other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest")
+ hidden_states = torch.cat((first_frame, other_frames), dim=2)
+ else:
+ hidden_states = first_frame
+
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class HunyuanVideoDownsampleCausal3D(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ out_channels: Optional[int] = None,
+ padding: int = 1,
+ kernel_size: int = 3,
+ bias: bool = True,
+ stride=2,
+ ) -> None:
+ super().__init__()
+ out_channels = out_channels or channels
+
+ self.conv = HunyuanVideoCausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class HunyuanVideoResnetBlockCausal3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ dropout: float = 0.0,
+ groups: int = 32,
+ eps: float = 1e-6,
+ non_linearity: str = "swish",
+ ) -> None:
+ super().__init__()
+ out_channels = out_channels or in_channels
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)
+ self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0)
+
+ self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0)
+
+ self.conv_shortcut = None
+ if in_channels != out_channels:
+ self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = hidden_states.contiguous()
+ residual = hidden_states
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ residual = self.conv_shortcut(residual)
+
+ hidden_states = hidden_states + residual
+ return hidden_states
+
+
+class HunyuanVideoMidBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ add_attention: bool = True,
+ attention_head_dim: int = 1,
+ ) -> None:
+ super().__init__()
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+ self.add_attention = add_attention
+
+ # There is always at least one resnet
+ resnets = [
+ HunyuanVideoResnetBlockCausal3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if self.add_attention:
+ attentions.append(
+ Attention(
+ in_channels,
+ heads=in_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+ else:
+ attentions.append(None)
+
+ resnets.append(
+ HunyuanVideoResnetBlockCausal3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states)
+
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
+ attention_mask = prepare_causal_attention_mask(
+ num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
+ )
+ hidden_states = attn(hidden_states, attention_mask=attention_mask)
+ hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
+
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
+
+ else:
+ hidden_states = self.resnets[0](hidden_states)
+
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
+ attention_mask = prepare_causal_attention_mask(
+ num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
+ )
+ hidden_states = attn(hidden_states, attention_mask=attention_mask)
+ hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
+
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanVideoDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ add_downsample: bool = True,
+ downsample_stride: int = 2,
+ downsample_padding: int = 1,
+ ) -> None:
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ HunyuanVideoResnetBlockCausal3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ HunyuanVideoDownsampleCausal3D(
+ out_channels,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ stride=downsample_stride,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for resnet in self.resnets:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
+ else:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanVideoUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ add_upsample: bool = True,
+ upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2),
+ ) -> None:
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ HunyuanVideoResnetBlockCausal3D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ HunyuanVideoUpsampleCausal3D(
+ out_channels,
+ out_channels=out_channels,
+ upsample_factor=upsample_scale_factor,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for resnet in self.resnets:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
+
+ else:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanVideoEncoder3D(nn.Module):
+ r"""
+ Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str, ...] = (
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ),
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ double_z: bool = True,
+ mid_block_add_attention=True,
+ temporal_compression_ratio: int = 4,
+ spatial_compression_ratio: int = 8,
+ ) -> None:
+ super().__init__()
+
+ self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ if down_block_type != "HunyuanVideoDownBlock3D":
+ raise ValueError(f"Unsupported down_block_type: {down_block_type}")
+
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
+ num_time_downsample_layers = int(np.log2(temporal_compression_ratio))
+
+ if temporal_compression_ratio == 4:
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
+ add_time_downsample = bool(
+ i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block
+ )
+ elif temporal_compression_ratio == 8:
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
+ add_time_downsample = bool(i < num_time_downsample_layers)
+ else:
+ raise ValueError(f"Unsupported time_compression_ratio: {temporal_compression_ratio}")
+
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
+
+ down_block = HunyuanVideoDownBlock3D(
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ downsample_stride=downsample_stride,
+ downsample_padding=0,
+ )
+
+ self.down_blocks.append(down_block)
+
+ self.mid_block = HunyuanVideoMidBlock3D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ add_attention=mid_block_add_attention,
+ )
+
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for down_block in self.down_blocks:
+ hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
+
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+ else:
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states)
+
+ hidden_states = self.mid_block(hidden_states)
+
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanVideoDecoder3D(nn.Module):
+ r"""
+ Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ up_block_types: Tuple[str, ...] = (
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ),
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ mid_block_add_attention=True,
+ time_compression_ratio: int = 4,
+ spatial_compression_ratio: int = 8,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
+ self.up_blocks = nn.ModuleList([])
+
+ # mid
+ self.mid_block = HunyuanVideoMidBlock3D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ add_attention=mid_block_add_attention,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ if up_block_type != "HunyuanVideoUpBlock3D":
+ raise ValueError(f"Unsupported up_block_type: {up_block_type}")
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
+
+ if time_compression_ratio == 4:
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
+ add_time_upsample = bool(
+ i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block
+ )
+ else:
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}")
+
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
+ upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
+
+ up_block = HunyuanVideoUpBlock3D(
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
+ upsample_scale_factor=upsample_scale_factor,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ )
+
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+
+ for up_block in self.up_blocks:
+ hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
+ else:
+ hidden_states = self.mid_block(hidden_states)
+
+ for up_block in self.up_blocks:
+ hidden_states = up_block(hidden_states)
+
+ # post-process
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+
+class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
+ r"""
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
+ Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ latent_channels: int = 16,
+ down_block_types: Tuple[str, ...] = (
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ),
+ up_block_types: Tuple[str, ...] = (
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ),
+ block_out_channels: Tuple[int] = (128, 256, 512, 512),
+ layers_per_block: int = 2,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ scaling_factor: float = 0.476986,
+ spatial_compression_ratio: int = 8,
+ temporal_compression_ratio: int = 4,
+ mid_block_add_attention: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.time_compression_ratio = temporal_compression_ratio
+
+ self.encoder = HunyuanVideoEncoder3D(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ double_z=True,
+ mid_block_add_attention=mid_block_add_attention,
+ temporal_compression_ratio=temporal_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ )
+
+ self.decoder = HunyuanVideoDecoder3D(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ time_compression_ratio=temporal_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ mid_block_add_attention=mid_block_add_attention,
+ )
+
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
+
+ self.spatial_compression_ratio = spatial_compression_ratio
+ self.temporal_compression_ratio = temporal_compression_ratio
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+ # at a fixed frame batch size (based on `self.tile_sample_min_num_frames`), the memory requirement can be lowered.
+ self.use_framewise_encoding = True
+ self.use_framewise_decoding = True
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 256
+ self.tile_sample_min_width = 256
+ self.tile_sample_min_num_frames = 16
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 192
+ self.tile_sample_stride_width = 192
+ self.tile_sample_stride_num_frames = 12
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_min_num_frames: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ tile_sample_stride_num_frames: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_min_num_frames (`int`, *optional*):
+ The minimum number of frames required for a sample to be separated into tiles across the frame
+ dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ tile_sample_stride_num_frames (`int`, *optional*):
+ The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts
+ produced across the frame dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+ self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = x.shape
+
+ if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
+ return self._temporal_tiled_encode(x)
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
+ x = self.encoder(x)
+ enc = self.quant_conv(x)
+ return enc
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+
+ if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+ return self._temporal_tiled_decode(z, return_dict=return_dict)
+
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ batch_size, num_channels, num_frames, height, width = x.shape
+ latent_height = height // self.spatial_compression_ratio
+ latent_width = width // self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+ tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+ return enc
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+
+ batch_size, num_channels, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+ def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+ batch_size, num_channels, num_frames, height, width = x.shape
+ latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
+
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+ tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+ blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
+
+ row = []
+ for i in range(0, num_frames, self.tile_sample_stride_num_frames):
+ tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
+ if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
+ tile = self.tiled_encode(tile)
+ else:
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ if i > 0:
+ tile = tile[:, :, 1:, :, :]
+ row.append(tile)
+
+ result_row = []
+ for i, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+ result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
+ else:
+ result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
+
+ enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
+ return enc
+
+ def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+ num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+ tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+ blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
+
+ row = []
+ for i in range(0, num_frames, tile_latent_stride_num_frames):
+ tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
+ if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
+ decoded = self.tiled_decode(tile, return_dict=True).sample
+ else:
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ if i > 0:
+ decoded = decoded[:, :, 1:, :, :]
+ row.append(decoded)
+
+ result_row = []
+ for i, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+ result_row.append(tile[:, :, : self.tile_sample_stride_num_frames, :, :])
+ else:
+ result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
+
+ dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+ return dec
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
new file mode 100644
index 000000000000..2b2f77a5509d
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
@@ -0,0 +1,1557 @@
+# Copyright 2024 The Lightricks team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import RMSNorm
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+class LTXVideoCausalConv3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ dilation: Union[int, Tuple[int, int, int]] = 1,
+ groups: int = 1,
+ padding_mode: str = "zeros",
+ is_causal: bool = True,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.is_causal = is_causal
+ self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
+
+ dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)
+ stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
+ height_pad = self.kernel_size[1] // 2
+ width_pad = self.kernel_size[2] // 2
+ padding = (0, height_pad, width_pad)
+
+ self.conv = nn.Conv3d(
+ in_channels,
+ out_channels,
+ self.kernel_size,
+ stride=stride,
+ dilation=dilation,
+ groups=groups,
+ padding=padding,
+ padding_mode=padding_mode,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ time_kernel_size = self.kernel_size[0]
+
+ if self.is_causal:
+ pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1))
+ hidden_states = torch.concatenate([pad_left, hidden_states], dim=2)
+ else:
+ pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
+ pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
+ hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2)
+
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class LTXVideoResnetBlock3d(nn.Module):
+ r"""
+ A 3D ResNet block used in the LTXVideo model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ elementwise_affine (`bool`, defaults to `False`):
+ Whether to enable elementwise affinity in the normalization layers.
+ non_linearity (`str`, defaults to `"swish"`):
+ Activation function to use.
+ conv_shortcut (bool, defaults to `False`):
+ Whether or not to use a convolution shortcut.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ dropout: float = 0.0,
+ eps: float = 1e-6,
+ elementwise_affine: bool = False,
+ non_linearity: str = "swish",
+ is_causal: bool = True,
+ inject_noise: bool = False,
+ timestep_conditioning: bool = False,
+ ) -> None:
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
+ self.conv1 = LTXVideoCausalConv3d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
+ )
+
+ self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = LTXVideoCausalConv3d(
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
+ )
+
+ self.norm3 = None
+ self.conv_shortcut = None
+ if in_channels != out_channels:
+ self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
+ self.conv_shortcut = LTXVideoCausalConv3d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
+ )
+
+ self.per_channel_scale1 = None
+ self.per_channel_scale2 = None
+ if inject_noise:
+ self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
+ self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
+
+ self.scale_shift_table = None
+ if timestep_conditioning:
+ self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
+
+ def forward(
+ self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None
+ ) -> torch.Tensor:
+ hidden_states = inputs
+
+ hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+ if self.scale_shift_table is not None:
+ temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
+ shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
+ hidden_states = hidden_states * (1 + scale_1) + shift_1
+
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ if self.per_channel_scale1 is not None:
+ spatial_shape = hidden_states.shape[-2:]
+ spatial_noise = torch.randn(
+ spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
+ )[None]
+ hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
+
+ hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+ if self.scale_shift_table is not None:
+ hidden_states = hidden_states * (1 + scale_2) + shift_2
+
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.per_channel_scale2 is not None:
+ spatial_shape = hidden_states.shape[-2:]
+ spatial_noise = torch.randn(
+ spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
+ )[None]
+ hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
+
+ if self.norm3 is not None:
+ inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
+
+ if self.conv_shortcut is not None:
+ inputs = self.conv_shortcut(inputs)
+
+ hidden_states = hidden_states + inputs
+ return hidden_states
+
+
+class LTXVideoDownsampler3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ is_causal: bool = True,
+ padding_mode: str = "zeros",
+ ) -> None:
+ super().__init__()
+
+ self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
+ self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels
+
+ out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
+
+ self.conv = LTXVideoCausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=1,
+ is_causal=is_causal,
+ padding_mode=padding_mode,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
+
+ residual = (
+ hidden_states.unflatten(4, (-1, self.stride[2]))
+ .unflatten(3, (-1, self.stride[1]))
+ .unflatten(2, (-1, self.stride[0]))
+ )
+ residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
+ residual = residual.unflatten(1, (-1, self.group_size))
+ residual = residual.mean(dim=2)
+
+ hidden_states = self.conv(hidden_states)
+ hidden_states = (
+ hidden_states.unflatten(4, (-1, self.stride[2]))
+ .unflatten(3, (-1, self.stride[1]))
+ .unflatten(2, (-1, self.stride[0]))
+ )
+ hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class LTXVideoUpsampler3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ is_causal: bool = True,
+ residual: bool = False,
+ upscale_factor: int = 1,
+ padding_mode: str = "zeros",
+ ) -> None:
+ super().__init__()
+
+ self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
+ self.residual = residual
+ self.upscale_factor = upscale_factor
+
+ out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
+
+ self.conv = LTXVideoCausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=1,
+ is_causal=is_causal,
+ padding_mode=padding_mode,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+ if self.residual:
+ residual = hidden_states.reshape(
+ batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
+ )
+ residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor
+ residual = residual.repeat(1, repeats, 1, 1, 1)
+ residual = residual[:, :, self.stride[0] - 1 :]
+
+ hidden_states = self.conv(hidden_states)
+ hidden_states = hidden_states.reshape(
+ batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
+ )
+ hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
+
+ if self.residual:
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class LTXVideoDownBlock3D(nn.Module):
+ r"""
+ Down block used in the LTXVideo model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ spatio_temporal_scale (`bool`, defaults to `True`):
+ Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
+ Whether or not to downsample across temporal dimension.
+ is_causal (`bool`, defaults to `True`):
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ spatio_temporal_scale: bool = True,
+ is_causal: bool = True,
+ ):
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+
+ resnets = []
+ for _ in range(num_layers):
+ resnets.append(
+ LTXVideoResnetBlock3d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout=dropout,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ is_causal=is_causal,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ self.downsamplers = None
+ if spatio_temporal_scale:
+ self.downsamplers = nn.ModuleList(
+ [
+ LTXVideoCausalConv3d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=3,
+ stride=(2, 2, 2),
+ is_causal=is_causal,
+ )
+ ]
+ )
+
+ self.conv_out = None
+ if in_channels != out_channels:
+ self.conv_out = LTXVideoResnetBlock3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ is_causal=is_causal,
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ ) -> torch.Tensor:
+ r"""Forward method of the `LTXDownBlock3D` class."""
+
+ for i, resnet in enumerate(self.resnets):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
+ else:
+ hidden_states = resnet(hidden_states, temb, generator)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ if self.conv_out is not None:
+ hidden_states = self.conv_out(hidden_states, temb, generator)
+
+ return hidden_states
+
+
+class LTXVideo095DownBlock3D(nn.Module):
+ r"""
+ Down block used in the LTXVideo model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ spatio_temporal_scale (`bool`, defaults to `True`):
+ Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
+ Whether or not to downsample across temporal dimension.
+ is_causal (`bool`, defaults to `True`):
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ spatio_temporal_scale: bool = True,
+ is_causal: bool = True,
+ downsample_type: str = "conv",
+ ):
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+
+ resnets = []
+ for _ in range(num_layers):
+ resnets.append(
+ LTXVideoResnetBlock3d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout=dropout,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ is_causal=is_causal,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ self.downsamplers = None
+ if spatio_temporal_scale:
+ self.downsamplers = nn.ModuleList()
+
+ if downsample_type == "conv":
+ self.downsamplers.append(
+ LTXVideoCausalConv3d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=3,
+ stride=(2, 2, 2),
+ is_causal=is_causal,
+ )
+ )
+ elif downsample_type == "spatial":
+ self.downsamplers.append(
+ LTXVideoDownsampler3d(
+ in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
+ )
+ )
+ elif downsample_type == "temporal":
+ self.downsamplers.append(
+ LTXVideoDownsampler3d(
+ in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
+ )
+ )
+ elif downsample_type == "spatiotemporal":
+ self.downsamplers.append(
+ LTXVideoDownsampler3d(
+ in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
+ )
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ ) -> torch.Tensor:
+ r"""Forward method of the `LTXDownBlock3D` class."""
+
+ for i, resnet in enumerate(self.resnets):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
+ else:
+ hidden_states = resnet(hidden_states, temb, generator)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
+class LTXVideoMidBlock3d(nn.Module):
+ r"""
+ A middle block used in the LTXVideo model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ is_causal (`bool`, defaults to `True`):
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ is_causal: bool = True,
+ inject_noise: bool = False,
+ timestep_conditioning: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.time_embedder = None
+ if timestep_conditioning:
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
+
+ resnets = []
+ for _ in range(num_layers):
+ resnets.append(
+ LTXVideoResnetBlock3d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout=dropout,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ is_causal=is_causal,
+ inject_noise=inject_noise,
+ timestep_conditioning=timestep_conditioning,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ ) -> torch.Tensor:
+ r"""Forward method of the `LTXMidBlock3D` class."""
+
+ if self.time_embedder is not None:
+ temb = self.time_embedder(
+ timestep=temb.flatten(),
+ resolution=None,
+ aspect_ratio=None,
+ batch_size=hidden_states.size(0),
+ hidden_dtype=hidden_states.dtype,
+ )
+ temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
+
+ for i, resnet in enumerate(self.resnets):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
+ else:
+ hidden_states = resnet(hidden_states, temb, generator)
+
+ return hidden_states
+
+
+class LTXVideoUpBlock3d(nn.Module):
+ r"""
+ Up block used in the LTXVideo model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ spatio_temporal_scale (`bool`, defaults to `True`):
+ Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
+ Whether or not to downsample across temporal dimension.
+ is_causal (`bool`, defaults to `True`):
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ resnet_eps: float = 1e-6,
+ resnet_act_fn: str = "swish",
+ spatio_temporal_scale: bool = True,
+ is_causal: bool = True,
+ inject_noise: bool = False,
+ timestep_conditioning: bool = False,
+ upsample_residual: bool = False,
+ upscale_factor: int = 1,
+ ):
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+
+ self.time_embedder = None
+ if timestep_conditioning:
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
+
+ self.conv_in = None
+ if in_channels != out_channels:
+ self.conv_in = LTXVideoResnetBlock3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ is_causal=is_causal,
+ inject_noise=inject_noise,
+ timestep_conditioning=timestep_conditioning,
+ )
+
+ self.upsamplers = None
+ if spatio_temporal_scale:
+ self.upsamplers = nn.ModuleList(
+ [
+ LTXVideoUpsampler3d(
+ out_channels * upscale_factor,
+ stride=(2, 2, 2),
+ is_causal=is_causal,
+ residual=upsample_residual,
+ upscale_factor=upscale_factor,
+ )
+ ]
+ )
+
+ resnets = []
+ for _ in range(num_layers):
+ resnets.append(
+ LTXVideoResnetBlock3d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ is_causal=is_causal,
+ inject_noise=inject_noise,
+ timestep_conditioning=timestep_conditioning,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ ) -> torch.Tensor:
+ if self.conv_in is not None:
+ hidden_states = self.conv_in(hidden_states, temb, generator)
+
+ if self.time_embedder is not None:
+ temb = self.time_embedder(
+ timestep=temb.flatten(),
+ resolution=None,
+ aspect_ratio=None,
+ batch_size=hidden_states.size(0),
+ hidden_dtype=hidden_states.dtype,
+ )
+ temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ for i, resnet in enumerate(self.resnets):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
+ else:
+ hidden_states = resnet(hidden_states, temb, generator)
+
+ return hidden_states
+
+
+class LTXVideoEncoder3d(nn.Module):
+ r"""
+ The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
+ representation.
+
+ Args:
+ in_channels (`int`, defaults to 3):
+ Number of input channels.
+ out_channels (`int`, defaults to 128):
+ Number of latent channels.
+ block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
+ The number of output channels for each block.
+ spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
+ Whether a block should contain spatio-temporal downscaling layers or not.
+ layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
+ The number of layers per block.
+ patch_size (`int`, defaults to `4`):
+ The size of spatial patches.
+ patch_size_t (`int`, defaults to `1`):
+ The size of temporal patches.
+ resnet_norm_eps (`float`, defaults to `1e-6`):
+ Epsilon value for ResNet normalization layers.
+ is_causal (`bool`, defaults to `True`):
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 128,
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ down_block_types: Tuple[str, ...] = (
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ ),
+ spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
+ layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
+ downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
+ patch_size: int = 4,
+ patch_size_t: int = 1,
+ resnet_norm_eps: float = 1e-6,
+ is_causal: bool = True,
+ ):
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.in_channels = in_channels * patch_size**2
+
+ output_channel = block_out_channels[0]
+
+ self.conv_in = LTXVideoCausalConv3d(
+ in_channels=self.in_channels,
+ out_channels=output_channel,
+ kernel_size=3,
+ stride=1,
+ is_causal=is_causal,
+ )
+
+ # down blocks
+ is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
+ num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
+ self.down_blocks = nn.ModuleList([])
+ for i in range(num_block_out_channels):
+ input_channel = output_channel
+ if not is_ltx_095:
+ output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
+ else:
+ output_channel = block_out_channels[i + 1]
+
+ if down_block_types[i] == "LTXVideoDownBlock3D":
+ down_block = LTXVideoDownBlock3D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ num_layers=layers_per_block[i],
+ resnet_eps=resnet_norm_eps,
+ spatio_temporal_scale=spatio_temporal_scaling[i],
+ is_causal=is_causal,
+ )
+ elif down_block_types[i] == "LTXVideo095DownBlock3D":
+ down_block = LTXVideo095DownBlock3D(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ num_layers=layers_per_block[i],
+ resnet_eps=resnet_norm_eps,
+ spatio_temporal_scale=spatio_temporal_scaling[i],
+ is_causal=is_causal,
+ downsample_type=downsample_type[i],
+ )
+ else:
+ raise ValueError(f"Unknown down block type: {down_block_types[i]}")
+
+ self.down_blocks.append(down_block)
+
+ # mid block
+ self.mid_block = LTXVideoMidBlock3d(
+ in_channels=output_channel,
+ num_layers=layers_per_block[-1],
+ resnet_eps=resnet_norm_eps,
+ is_causal=is_causal,
+ )
+
+ # out
+ self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
+ self.conv_act = nn.SiLU()
+ self.conv_out = LTXVideoCausalConv3d(
+ in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ r"""The forward method of the `LTXVideoEncoder3d` class."""
+
+ p = self.patch_size
+ p_t = self.patch_size_t
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p
+ post_patch_width = width // p
+
+ hidden_states = hidden_states.reshape(
+ batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p
+ )
+ # Thanks for driving me insane with the weird patching order :(
+ hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)
+ hidden_states = self.conv_in(hidden_states)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for down_block in self.down_blocks:
+ hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
+
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+ else:
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states)
+
+ hidden_states = self.mid_block(hidden_states)
+
+ hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ last_channel = hidden_states[:, -1:]
+ last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)
+ hidden_states = torch.cat([hidden_states, last_channel], dim=1)
+
+ return hidden_states
+
+
+class LTXVideoDecoder3d(nn.Module):
+ r"""
+ The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
+ sample.
+
+ Args:
+ in_channels (`int`, defaults to 128):
+ Number of latent channels.
+ out_channels (`int`, defaults to 3):
+ Number of output channels.
+ block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
+ The number of output channels for each block.
+ spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
+ Whether a block should contain spatio-temporal upscaling layers or not.
+ layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
+ The number of layers per block.
+ patch_size (`int`, defaults to `4`):
+ The size of spatial patches.
+ patch_size_t (`int`, defaults to `1`):
+ The size of temporal patches.
+ resnet_norm_eps (`float`, defaults to `1e-6`):
+ Epsilon value for ResNet normalization layers.
+ is_causal (`bool`, defaults to `False`):
+ Whether this layer behaves causally (future frames depend only on past frames) or not.
+ timestep_conditioning (`bool`, defaults to `False`):
+ Whether to condition the model on timesteps.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 128,
+ out_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
+ layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
+ patch_size: int = 4,
+ patch_size_t: int = 1,
+ resnet_norm_eps: float = 1e-6,
+ is_causal: bool = False,
+ inject_noise: Tuple[bool, ...] = (False, False, False, False),
+ timestep_conditioning: bool = False,
+ upsample_residual: Tuple[bool, ...] = (False, False, False, False),
+ upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
+ ) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.out_channels = out_channels * patch_size**2
+
+ block_out_channels = tuple(reversed(block_out_channels))
+ spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
+ layers_per_block = tuple(reversed(layers_per_block))
+ inject_noise = tuple(reversed(inject_noise))
+ upsample_residual = tuple(reversed(upsample_residual))
+ upsample_factor = tuple(reversed(upsample_factor))
+ output_channel = block_out_channels[0]
+
+ self.conv_in = LTXVideoCausalConv3d(
+ in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
+ )
+
+ self.mid_block = LTXVideoMidBlock3d(
+ in_channels=output_channel,
+ num_layers=layers_per_block[0],
+ resnet_eps=resnet_norm_eps,
+ is_causal=is_causal,
+ inject_noise=inject_noise[0],
+ timestep_conditioning=timestep_conditioning,
+ )
+
+ # up blocks
+ num_block_out_channels = len(block_out_channels)
+ self.up_blocks = nn.ModuleList([])
+ for i in range(num_block_out_channels):
+ input_channel = output_channel // upsample_factor[i]
+ output_channel = block_out_channels[i] // upsample_factor[i]
+
+ up_block = LTXVideoUpBlock3d(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ num_layers=layers_per_block[i + 1],
+ resnet_eps=resnet_norm_eps,
+ spatio_temporal_scale=spatio_temporal_scaling[i],
+ is_causal=is_causal,
+ inject_noise=inject_noise[i + 1],
+ timestep_conditioning=timestep_conditioning,
+ upsample_residual=upsample_residual[i],
+ upscale_factor=upsample_factor[i],
+ )
+
+ self.up_blocks.append(up_block)
+
+ # out
+ self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
+ self.conv_act = nn.SiLU()
+ self.conv_out = LTXVideoCausalConv3d(
+ in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
+ )
+
+ # timestep embedding
+ self.time_embedder = None
+ self.scale_shift_table = None
+ self.timestep_scale_multiplier = None
+ if timestep_conditioning:
+ self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states)
+
+ if self.timestep_scale_multiplier is not None:
+ temb = temb * self.timestep_scale_multiplier
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
+
+ for up_block in self.up_blocks:
+ hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb)
+ else:
+ hidden_states = self.mid_block(hidden_states, temb)
+
+ for up_block in self.up_blocks:
+ hidden_states = up_block(hidden_states, temb)
+
+ hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+ if self.time_embedder is not None:
+ temb = self.time_embedder(
+ timestep=temb.flatten(),
+ resolution=None,
+ aspect_ratio=None,
+ batch_size=hidden_states.size(0),
+ hidden_dtype=hidden_states.dtype,
+ )
+ temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
+ temb = temb + self.scale_shift_table[None, ..., None, None, None]
+ shift, scale = temb.unbind(dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ p = self.patch_size
+ p_t = self.patch_size_t
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width)
+ hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ return hidden_states
+
+
+class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
+ [LTX](https://huggingface.co/Lightricks/LTX-Video).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Args:
+ in_channels (`int`, defaults to `3`):
+ Number of input channels.
+ out_channels (`int`, defaults to `3`):
+ Number of output channels.
+ latent_channels (`int`, defaults to `128`):
+ Number of latent channels.
+ block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
+ The number of output channels for each block.
+ spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
+ Whether a block should contain spatio-temporal downscaling or not.
+ layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
+ The number of layers per block.
+ patch_size (`int`, defaults to `4`):
+ The size of spatial patches.
+ patch_size_t (`int`, defaults to `1`):
+ The size of temporal patches.
+ resnet_norm_eps (`float`, defaults to `1e-6`):
+ Epsilon value for ResNet normalization layers.
+ scaling_factor (`float`, *optional*, defaults to `1.0`):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ encoder_causal (`bool`, defaults to `True`):
+ Whether the encoder should behave causally (future frames depend only on past frames) or not.
+ decoder_causal (`bool`, defaults to `False`):
+ Whether the decoder should behave causally (future frames depend only on past frames) or not.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ latent_channels: int = 128,
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ down_block_types: Tuple[str, ...] = (
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ "LTXVideoDownBlock3D",
+ ),
+ decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
+ decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
+ spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
+ decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
+ decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
+ downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
+ upsample_residual: Tuple[bool, ...] = (False, False, False, False),
+ upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
+ timestep_conditioning: bool = False,
+ patch_size: int = 4,
+ patch_size_t: int = 1,
+ resnet_norm_eps: float = 1e-6,
+ scaling_factor: float = 1.0,
+ encoder_causal: bool = True,
+ decoder_causal: bool = False,
+ spatial_compression_ratio: int = None,
+ temporal_compression_ratio: int = None,
+ ) -> None:
+ super().__init__()
+
+ self.encoder = LTXVideoEncoder3d(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ block_out_channels=block_out_channels,
+ down_block_types=down_block_types,
+ spatio_temporal_scaling=spatio_temporal_scaling,
+ layers_per_block=layers_per_block,
+ downsample_type=downsample_type,
+ patch_size=patch_size,
+ patch_size_t=patch_size_t,
+ resnet_norm_eps=resnet_norm_eps,
+ is_causal=encoder_causal,
+ )
+ self.decoder = LTXVideoDecoder3d(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ block_out_channels=decoder_block_out_channels,
+ spatio_temporal_scaling=decoder_spatio_temporal_scaling,
+ layers_per_block=decoder_layers_per_block,
+ patch_size=patch_size,
+ patch_size_t=patch_size_t,
+ resnet_norm_eps=resnet_norm_eps,
+ is_causal=decoder_causal,
+ timestep_conditioning=timestep_conditioning,
+ inject_noise=decoder_inject_noise,
+ upsample_residual=upsample_residual,
+ upsample_factor=upsample_factor,
+ )
+
+ latents_mean = torch.zeros((latent_channels,), requires_grad=False)
+ latents_std = torch.ones((latent_channels,), requires_grad=False)
+ self.register_buffer("latents_mean", latents_mean, persistent=True)
+ self.register_buffer("latents_std", latents_std, persistent=True)
+
+ self.spatial_compression_ratio = (
+ patch_size * 2 ** sum(spatio_temporal_scaling)
+ if spatial_compression_ratio is None
+ else spatial_compression_ratio
+ )
+ self.temporal_compression_ratio = (
+ patch_size_t * 2 ** sum(spatio_temporal_scaling)
+ if temporal_compression_ratio is None
+ else temporal_compression_ratio
+ )
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+ # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
+ self.use_framewise_encoding = False
+ self.use_framewise_decoding = False
+
+ # This can be configured based on the amount of GPU memory available.
+ # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
+ # Setting it to higher values results in higher memory usage.
+ self.num_sample_frames_batch_size = 16
+ self.num_latent_frames_batch_size = 2
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 512
+ self.tile_sample_min_width = 512
+ self.tile_sample_min_num_frames = 16
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 448
+ self.tile_sample_stride_width = 448
+ self.tile_sample_stride_num_frames = 8
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_min_num_frames: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ tile_sample_stride_num_frames: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+ self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = x.shape
+
+ if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
+ return self._temporal_tiled_encode(x)
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
+ enc = self.encoder(x)
+
+ return enc
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(
+ self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+
+ if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
+ return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
+
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+ return self.tiled_decode(z, temb, return_dict=return_dict)
+
+ dec = self.decoder(z, temb)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(
+ self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ if temb is not None:
+ decoded_slices = [
+ self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1))
+ ]
+ else:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z, temb).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ batch_size, num_channels, num_frames, height, width = x.shape
+ latent_height = height // self.spatial_compression_ratio
+ latent_width = width // self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+ time = self.encoder(
+ x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+ )
+
+ row.append(time)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+ return enc
+
+ def tiled_decode(
+ self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+
+ batch_size, num_channels, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)
+
+ row.append(time)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+ batch_size, num_channels, num_frames, height, width = x.shape
+ latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
+
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+ tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+ blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
+
+ row = []
+ for i in range(0, num_frames, self.tile_sample_stride_num_frames):
+ tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
+ if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
+ tile = self.tiled_encode(tile)
+ else:
+ tile = self.encoder(tile)
+ if i > 0:
+ tile = tile[:, :, 1:, :, :]
+ row.append(tile)
+
+ result_row = []
+ for i, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+ result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
+ else:
+ result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
+
+ enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
+ return enc
+
+ def _temporal_tiled_decode(
+ self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+ num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
+ tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
+ blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
+
+ row = []
+ for i in range(0, num_frames, tile_latent_stride_num_frames):
+ tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
+ if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
+ decoded = self.tiled_decode(tile, temb, return_dict=True).sample
+ else:
+ decoded = self.decoder(tile, temb)
+ if i > 0:
+ decoded = decoded[:, :, :-1, :, :]
+ row.append(decoded)
+
+ result_row = []
+ for i, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_t(row[i - 1], tile, blend_num_frames)
+ tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :]
+ result_row.append(tile)
+ else:
+ result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
+
+ dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
+
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[torch.Tensor, torch.Tensor]:
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, temb)
+ if not return_dict:
+ return (dec.sample,)
+ return dec
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py b/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py
new file mode 100644
index 000000000000..7b53192033dc
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py
@@ -0,0 +1,1094 @@
+# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class EasyAnimateCausalConv3d(nn.Conv3d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, ...]] = 3,
+ stride: Union[int, Tuple[int, ...]] = 1,
+ padding: Union[int, Tuple[int, ...]] = 1,
+ dilation: Union[int, Tuple[int, ...]] = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = "zeros",
+ ):
+ # Ensure kernel_size, stride, and dilation are tuples of length 3
+ kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
+ assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."
+
+ stride = stride if isinstance(stride, tuple) else (stride,) * 3
+ assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."
+
+ dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
+ assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."
+
+ # Unpack kernel size, stride, and dilation for temporal, height, and width dimensions
+ t_ks, h_ks, w_ks = kernel_size
+ self.t_stride, h_stride, w_stride = stride
+ t_dilation, h_dilation, w_dilation = dilation
+
+ # Calculate padding for temporal dimension to maintain causality
+ t_pad = (t_ks - 1) * t_dilation
+
+ # Calculate padding for height and width dimensions based on the padding parameter
+ if padding is None:
+ h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
+ w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
+ elif isinstance(padding, int):
+ h_pad = w_pad = padding
+ else:
+ assert NotImplementedError
+
+ # Store temporal padding and initialize flags and previous features cache
+ self.temporal_padding = t_pad
+ self.temporal_padding_origin = math.ceil(((t_ks - 1) * w_dilation + (1 - w_stride)) / 2)
+
+ self.prev_features = None
+
+ # Initialize the parent class with modified padding
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=(0, h_pad, w_pad),
+ groups=groups,
+ bias=bias,
+ padding_mode=padding_mode,
+ )
+
+ def _clear_conv_cache(self):
+ del self.prev_features
+ self.prev_features = None
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # Ensure input tensor is of the correct type
+ dtype = hidden_states.dtype
+ if self.prev_features is None:
+ # Pad the input tensor in the temporal dimension to maintain causality
+ hidden_states = F.pad(
+ hidden_states,
+ pad=(0, 0, 0, 0, self.temporal_padding, 0),
+ mode="replicate", # TODO: check if this is necessary
+ )
+ hidden_states = hidden_states.to(dtype=dtype)
+
+ # Clear cache before processing and store previous features for causality
+ self._clear_conv_cache()
+ self.prev_features = hidden_states[:, :, -self.temporal_padding :].clone()
+
+ # Process the input tensor in chunks along the temporal dimension
+ num_frames = hidden_states.size(2)
+ outputs = []
+ i = 0
+ while i + self.temporal_padding + 1 <= num_frames:
+ out = super().forward(hidden_states[:, :, i : i + self.temporal_padding + 1])
+ i += self.t_stride
+ outputs.append(out)
+ return torch.concat(outputs, 2)
+ else:
+ # Concatenate previous features with the input tensor for continuous temporal processing
+ if self.t_stride == 2:
+ hidden_states = torch.concat(
+ [self.prev_features[:, :, -(self.temporal_padding - 1) :], hidden_states], dim=2
+ )
+ else:
+ hidden_states = torch.concat([self.prev_features, hidden_states], dim=2)
+ hidden_states = hidden_states.to(dtype=dtype)
+
+ # Clear cache and update previous features
+ self._clear_conv_cache()
+ self.prev_features = hidden_states[:, :, -self.temporal_padding :].clone()
+
+ # Process the concatenated tensor in chunks along the temporal dimension
+ num_frames = hidden_states.size(2)
+ outputs = []
+ i = 0
+ while i + self.temporal_padding + 1 <= num_frames:
+ out = super().forward(hidden_states[:, :, i : i + self.temporal_padding + 1])
+ i += self.t_stride
+ outputs.append(out)
+ return torch.concat(outputs, 2)
+
+
+class EasyAnimateResidualBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ non_linearity: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ spatial_group_norm: bool = True,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ ):
+ super().__init__()
+
+ self.output_scale_factor = output_scale_factor
+
+ # Group normalization for input tensor
+ self.norm1 = nn.GroupNorm(
+ num_groups=norm_num_groups,
+ num_channels=in_channels,
+ eps=norm_eps,
+ affine=True,
+ )
+ self.nonlinearity = get_activation(non_linearity)
+ self.conv1 = EasyAnimateCausalConv3d(in_channels, out_channels, kernel_size=3)
+
+ self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = EasyAnimateCausalConv3d(out_channels, out_channels, kernel_size=3)
+
+ if in_channels != out_channels:
+ self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1)
+ else:
+ self.shortcut = nn.Identity()
+
+ self.spatial_group_norm = spatial_group_norm
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ shortcut = self.shortcut(hidden_states)
+
+ if self.spatial_group_norm:
+ batch_size = hidden_states.size(0)
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, T, H, W] -> [B * T, C, H, W]
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
+ 0, 2, 1, 3, 4
+ ) # [B * T, C, H, W] -> [B, C, T, H, W]
+ else:
+ hidden_states = self.norm1(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ if self.spatial_group_norm:
+ batch_size = hidden_states.size(0)
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, T, H, W] -> [B * T, C, H, W]
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
+ 0, 2, 1, 3, 4
+ ) # [B * T, C, H, W] -> [B, C, T, H, W]
+ else:
+ hidden_states = self.norm2(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ return (hidden_states + shortcut) / self.output_scale_factor
+
+
+class EasyAnimateDownsampler3D(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: tuple = (2, 2, 2)):
+ super().__init__()
+
+ self.conv = EasyAnimateCausalConv3d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=0
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = F.pad(hidden_states, (0, 1, 0, 1))
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class EasyAnimateUpsampler3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ temporal_upsample: bool = False,
+ spatial_group_norm: bool = True,
+ ):
+ super().__init__()
+ out_channels = out_channels or in_channels
+
+ self.temporal_upsample = temporal_upsample
+ self.spatial_group_norm = spatial_group_norm
+
+ self.conv = EasyAnimateCausalConv3d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size
+ )
+ self.prev_features = None
+
+ def _clear_conv_cache(self):
+ del self.prev_features
+ self.prev_features = None
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = F.interpolate(hidden_states, scale_factor=(1, 2, 2), mode="nearest")
+ hidden_states = self.conv(hidden_states)
+
+ if self.temporal_upsample:
+ if self.prev_features is None:
+ self.prev_features = hidden_states
+ else:
+ hidden_states = F.interpolate(
+ hidden_states,
+ scale_factor=(2, 1, 1),
+ mode="trilinear" if not self.spatial_group_norm else "nearest",
+ )
+ return hidden_states
+
+
+class EasyAnimateDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ spatial_group_norm: bool = True,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ add_temporal_downsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ EasyAnimateResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ spatial_group_norm=spatial_group_norm,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ if add_downsample and add_temporal_downsample:
+ self.downsampler = EasyAnimateDownsampler3D(out_channels, out_channels, kernel_size=3, stride=(2, 2, 2))
+ self.spatial_downsample_factor = 2
+ self.temporal_downsample_factor = 2
+ elif add_downsample and not add_temporal_downsample:
+ self.downsampler = EasyAnimateDownsampler3D(out_channels, out_channels, kernel_size=3, stride=(1, 2, 2))
+ self.spatial_downsample_factor = 2
+ self.temporal_downsample_factor = 1
+ else:
+ self.downsampler = None
+ self.spatial_downsample_factor = 1
+ self.temporal_downsample_factor = 1
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for conv in self.convs:
+ hidden_states = conv(hidden_states)
+ if self.downsampler is not None:
+ hidden_states = self.downsampler(hidden_states)
+ return hidden_states
+
+
+class EasyAnimateUpBlock3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ spatial_group_norm: bool = False,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ add_temporal_upsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ EasyAnimateResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ spatial_group_norm=spatial_group_norm,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ if add_upsample:
+ self.upsampler = EasyAnimateUpsampler3D(
+ in_channels,
+ in_channels,
+ temporal_upsample=add_temporal_upsample,
+ spatial_group_norm=spatial_group_norm,
+ )
+ else:
+ self.upsampler = None
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for conv in self.convs:
+ hidden_states = conv(hidden_states)
+ if self.upsampler is not None:
+ hidden_states = self.upsampler(hidden_states)
+ return hidden_states
+
+
+class EasyAnimateMidBlock3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ spatial_group_norm: bool = True,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ ):
+ super().__init__()
+
+ norm_num_groups = norm_num_groups if norm_num_groups is not None else min(in_channels // 4, 32)
+
+ self.convs = nn.ModuleList(
+ [
+ EasyAnimateResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ spatial_group_norm=spatial_group_norm,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ ]
+ )
+
+ for _ in range(num_layers - 1):
+ self.convs.append(
+ EasyAnimateResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ spatial_group_norm=spatial_group_norm,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.convs[0](hidden_states)
+ for resnet in self.convs[1:]:
+ hidden_states = resnet(hidden_states)
+ return hidden_states
+
+
+class EasyAnimateEncoder(nn.Module):
+ r"""
+ Causal encoder for 3D video-like data used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 8,
+ down_block_types: Tuple[str, ...] = (
+ "SpatialDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ ),
+ block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ double_z: bool = True,
+ spatial_group_norm: bool = False,
+ ):
+ super().__init__()
+
+ # 1. Input convolution
+ self.conv_in = EasyAnimateCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
+
+ # 2. Down blocks
+ self.down_blocks = nn.ModuleList([])
+ output_channels = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channels = output_channels
+ output_channels = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ if down_block_type == "SpatialDownBlock3D":
+ down_block = EasyAnimateDownBlock3D(
+ in_channels=input_channels,
+ out_channels=output_channels,
+ num_layers=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=1e-6,
+ spatial_group_norm=spatial_group_norm,
+ add_downsample=not is_final_block,
+ add_temporal_downsample=False,
+ )
+ elif down_block_type == "SpatialTemporalDownBlock3D":
+ down_block = EasyAnimateDownBlock3D(
+ in_channels=input_channels,
+ out_channels=output_channels,
+ num_layers=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=1e-6,
+ spatial_group_norm=spatial_group_norm,
+ add_downsample=not is_final_block,
+ add_temporal_downsample=True,
+ )
+ else:
+ raise ValueError(f"Unknown up block type: {down_block_type}")
+ self.down_blocks.append(down_block)
+
+ # 3. Middle block
+ self.mid_block = EasyAnimateMidBlock3d(
+ in_channels=block_out_channels[-1],
+ num_layers=layers_per_block,
+ act_fn=act_fn,
+ spatial_group_norm=spatial_group_norm,
+ norm_num_groups=norm_num_groups,
+ norm_eps=1e-6,
+ dropout=0,
+ output_scale_factor=1,
+ )
+
+ # 4. Output normalization & convolution
+ self.spatial_group_norm = spatial_group_norm
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[-1],
+ num_groups=norm_num_groups,
+ eps=1e-6,
+ )
+ self.conv_act = get_activation(act_fn)
+
+ # Initialize the output convolution layer
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = EasyAnimateCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # hidden_states: (B, C, T, H, W)
+ hidden_states = self.conv_in(hidden_states)
+
+ for down_block in self.down_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
+ else:
+ hidden_states = down_block(hidden_states)
+
+ hidden_states = self.mid_block(hidden_states)
+
+ if self.spatial_group_norm:
+ batch_size = hidden_states.size(0)
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ else:
+ hidden_states = self.conv_norm_out(hidden_states)
+
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class EasyAnimateDecoder(nn.Module):
+ r"""
+ Causal decoder for 3D video-like data used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: int = 8,
+ out_channels: int = 3,
+ up_block_types: Tuple[str, ...] = (
+ "SpatialUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ ),
+ block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ spatial_group_norm: bool = False,
+ ):
+ super().__init__()
+
+ # 1. Input convolution
+ self.conv_in = EasyAnimateCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3)
+
+ # 2. Middle block
+ self.mid_block = EasyAnimateMidBlock3d(
+ in_channels=block_out_channels[-1],
+ num_layers=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=1e-6,
+ dropout=0,
+ output_scale_factor=1,
+ )
+
+ # 3. Up blocks
+ self.up_blocks = nn.ModuleList([])
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channels = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ input_channels = output_channels
+ output_channels = reversed_block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ # Create and append up block to up_blocks
+ if up_block_type == "SpatialUpBlock3D":
+ up_block = EasyAnimateUpBlock3d(
+ in_channels=input_channels,
+ out_channels=output_channels,
+ num_layers=layers_per_block + 1,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=1e-6,
+ spatial_group_norm=spatial_group_norm,
+ add_upsample=not is_final_block,
+ add_temporal_upsample=False,
+ )
+ elif up_block_type == "SpatialTemporalUpBlock3D":
+ up_block = EasyAnimateUpBlock3d(
+ in_channels=input_channels,
+ out_channels=output_channels,
+ num_layers=layers_per_block + 1,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=1e-6,
+ spatial_group_norm=spatial_group_norm,
+ add_upsample=not is_final_block,
+ add_temporal_upsample=True,
+ )
+ else:
+ raise ValueError(f"Unknown up block type: {up_block_type}")
+ self.up_blocks.append(up_block)
+
+ # Output normalization and activation
+ self.spatial_group_norm = spatial_group_norm
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0],
+ num_groups=norm_num_groups,
+ eps=1e-6,
+ )
+ self.conv_act = get_activation(act_fn)
+
+ # Output convolution layer
+ self.conv_out = EasyAnimateCausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # hidden_states: (B, C, T, H, W)
+ hidden_states = self.conv_in(hidden_states)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+ else:
+ hidden_states = self.mid_block(hidden_states)
+
+ for up_block in self.up_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
+ else:
+ hidden_states = up_block(hidden_states)
+
+ if self.spatial_group_norm:
+ batch_size = hidden_states.size(0)
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, T, H, W] -> [B * T, C, H, W]
+ hidden_states = self.conv_norm_out(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
+ 0, 2, 1, 3, 4
+ ) # [B * T, C, H, W] -> [B, C, T, H, W]
+ else:
+ hidden_states = self.conv_norm_out(hidden_states)
+
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
+ model is used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ latent_channels: int = 16,
+ out_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
+ down_block_types: Tuple[str, ...] = [
+ "SpatialDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ ],
+ up_block_types: Tuple[str, ...] = [
+ "SpatialUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ ],
+ layers_per_block: int = 2,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ scaling_factor: float = 0.7125,
+ spatial_group_norm: bool = True,
+ ):
+ super().__init__()
+
+ # Initialize the encoder
+ self.encoder = EasyAnimateEncoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ double_z=True,
+ spatial_group_norm=spatial_group_norm,
+ )
+
+ # Initialize the decoder
+ self.decoder = EasyAnimateDecoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ spatial_group_norm=spatial_group_norm,
+ )
+
+ # Initialize convolution layers for quantization and post-quantization
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
+
+ self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
+ self.temporal_compression_ratio = 2 ** (len(block_out_channels) - 2)
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+ # at a fixed frame batch size (based on `self.num_latent_frames_batch_size`), the memory requirement can be lowered.
+ self.use_framewise_encoding = False
+ self.use_framewise_decoding = False
+
+ # Assign mini-batch sizes for encoder and decoder
+ self.num_sample_frames_batch_size = 4
+ self.num_latent_frames_batch_size = 1
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 512
+ self.tile_sample_min_width = 512
+ self.tile_sample_min_num_frames = 4
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 448
+ self.tile_sample_stride_width = 448
+ self.tile_sample_stride_num_frames = 8
+
+ def _clear_conv_cache(self):
+ # Clear cache for convolutional layers if needed
+ for name, module in self.named_modules():
+ if isinstance(module, EasyAnimateCausalConv3d):
+ module._clear_conv_cache()
+ if isinstance(module, EasyAnimateUpsampler3D):
+ module._clear_conv_cache()
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_min_num_frames: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ tile_sample_stride_num_frames: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ """
+ self.use_tiling = True
+ self.use_framewise_decoding = True
+ self.use_framewise_encoding = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+ self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ @apply_forward_hook
+ def _encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_height or x.shape[-2] > self.tile_sample_min_width):
+ return self.tiled_encode(x, return_dict=return_dict)
+
+ first_frames = self.encoder(x[:, :, :1, :, :])
+ h = [first_frames]
+ for i in range(1, x.shape[2], self.num_sample_frames_batch_size):
+ next_frames = self.encoder(x[:, :, i : i + self.num_sample_frames_batch_size, :, :])
+ h.append(next_frames)
+ h = torch.cat(h, dim=2)
+ moments = self.quant_conv(h)
+
+ self._clear_conv_cache()
+ return moments
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ if self.use_tiling and (z.shape[-1] > tile_latent_min_height or z.shape[-2] > tile_latent_min_width):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ z = self.post_quant_conv(z)
+
+ # Process the first frame and save the result
+ first_frames = self.decoder(z[:, :, :1, :, :])
+ # Initialize the list to store the processed frames, starting with the first frame
+ dec = [first_frames]
+ # Process the remaining frames, with the number of frames processed at a time determined by mini_batch_decoder
+ for i in range(1, z.shape[2], self.num_latent_frames_batch_size):
+ next_frames = self.decoder(z[:, :, i : i + self.num_latent_frames_batch_size, :, :])
+ dec.append(next_frames)
+ # Concatenate all processed frames along the channel dimension
+ dec = torch.cat(dec, dim=2)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ self._clear_conv_cache()
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ batch_size, num_channels, num_frames, height, width = x.shape
+ latent_height = height // self.spatial_compression_ratio
+ latent_width = width // self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ # Split the image into 512x512 tiles and encode them separately.
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+ tile = x[
+ :,
+ :,
+ :,
+ i : i + self.tile_sample_min_height,
+ j : j + self.tile_sample_min_width,
+ ]
+
+ first_frames = self.encoder(tile[:, :, 0:1, :, :])
+ tile_h = [first_frames]
+ for k in range(1, num_frames, self.num_sample_frames_batch_size):
+ next_frames = self.encoder(tile[:, :, k : k + self.num_sample_frames_batch_size, :, :])
+ tile_h.append(next_frames)
+ tile = torch.cat(tile_h, dim=2)
+ tile = self.quant_conv(tile)
+ self._clear_conv_cache()
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :latent_height, :latent_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ moments = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+ return moments
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+ # Split z into overlapping 64x64 tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ tile = z[
+ :,
+ :,
+ :,
+ i : i + tile_latent_min_height,
+ j : j + tile_latent_min_width,
+ ]
+ tile = self.post_quant_conv(tile)
+
+ # Process the first frame and save the result
+ first_frames = self.decoder(tile[:, :, :1, :, :])
+ # Initialize the list to store the processed frames, starting with the first frame
+ tile_dec = [first_frames]
+ # Process the remaining frames, with the number of frames processed at a time determined by mini_batch_decoder
+ for k in range(1, num_frames, self.num_latent_frames_batch_size):
+ next_frames = self.decoder(tile[:, :, k : k + self.num_latent_frames_batch_size, :, :])
+ tile_dec.append(next_frames)
+ # Concatenate all processed frames along the channel dimension
+ decoded = torch.cat(tile_dec, dim=2)
+ self._clear_conv_cache()
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
new file mode 100644
index 000000000000..d69ec6252b00
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -0,0 +1,1129 @@
+# Copyright 2024 The Mochi team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class MochiChunkedGroupNorm3D(nn.Module):
+ r"""
+ Applies per-frame group normalization for 5D video inputs. It also supports memory-efficient chunked group
+ normalization.
+
+ Args:
+ num_channels (int): Number of channels expected in input
+ num_groups (int, optional): Number of groups to separate the channels into. Default: 32
+ affine (bool, optional): If True, this module has learnable affine parameters. Default: True
+ chunk_size (int, optional): Size of each chunk for processing. Default: 8
+
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ num_groups: int = 32,
+ affine: bool = True,
+ chunk_size: int = 8,
+ ):
+ super().__init__()
+ self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine)
+ self.chunk_size = chunk_size
+
+ def forward(self, x: torch.Tensor = None) -> torch.Tensor:
+ batch_size = x.size(0)
+
+ x = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ output = torch.cat([self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], dim=0)
+ output = output.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+
+ return output
+
+
+class MochiResnetBlock3D(nn.Module):
+ r"""
+ A 3D ResNet block used in the Mochi model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ non_linearity (`str`, defaults to `"swish"`):
+ Activation function to use.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ act_fn: str = "swish",
+ ):
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.nonlinearity = get_activation(act_fn)
+
+ self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels)
+ self.conv1 = CogVideoXCausalConv3d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate"
+ )
+ self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels)
+ self.conv2 = CogVideoXCausalConv3d(
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate"
+ )
+
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ new_conv_cache = {}
+ conv_cache = conv_cache or {}
+
+ hidden_states = inputs
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
+
+ hidden_states = hidden_states + inputs
+ return hidden_states, new_conv_cache
+
+
+class MochiDownBlock3D(nn.Module):
+ r"""
+ An downsampling block used in the Mochi model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet blocks in the block.
+ temporal_expansion (`int`, defaults to `2`):
+ Temporal expansion factor.
+ spatial_expansion (`int`, defaults to `2`):
+ Spatial expansion factor.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ temporal_expansion: int = 2,
+ spatial_expansion: int = 2,
+ add_attention: bool = True,
+ ):
+ super().__init__()
+ self.temporal_expansion = temporal_expansion
+ self.spatial_expansion = spatial_expansion
+
+ self.conv_in = CogVideoXCausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(temporal_expansion, spatial_expansion, spatial_expansion),
+ stride=(temporal_expansion, spatial_expansion, spatial_expansion),
+ pad_mode="replicate",
+ )
+
+ resnets = []
+ norms = []
+ attentions = []
+ for _ in range(num_layers):
+ resnets.append(MochiResnetBlock3D(in_channels=out_channels))
+ if add_attention:
+ norms.append(MochiChunkedGroupNorm3D(num_channels=out_channels))
+ attentions.append(
+ Attention(
+ query_dim=out_channels,
+ heads=out_channels // 32,
+ dim_head=32,
+ qk_norm="l2",
+ is_causal=True,
+ processor=MochiVaeAttnProcessor2_0(),
+ )
+ )
+ else:
+ norms.append(None)
+ attentions.append(None)
+
+ self.resnets = nn.ModuleList(resnets)
+ self.norms = nn.ModuleList(norms)
+ self.attentions = nn.ModuleList(attentions)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+ chunk_size: int = 2**15,
+ ) -> torch.Tensor:
+ r"""Forward method of the `MochiUpBlock3D` class."""
+
+ new_conv_cache = {}
+ conv_cache = conv_cache or {}
+
+ hidden_states, new_conv_cache["conv_in"] = self.conv_in(hidden_states)
+
+ for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
+ conv_cache_key = f"resnet_{i}"
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+ resnet,
+ hidden_states,
+ conv_cache=conv_cache.get(conv_cache_key),
+ )
+ else:
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+ )
+
+ if attn is not None:
+ residual = hidden_states
+ hidden_states = norm(hidden_states)
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous()
+
+ # Perform attention in chunks to avoid following error:
+ # RuntimeError: CUDA error: invalid configuration argument
+ if hidden_states.size(0) <= chunk_size:
+ hidden_states = attn(hidden_states)
+ else:
+ hidden_states_chunks = []
+ for i in range(0, hidden_states.size(0), chunk_size):
+ hidden_states_chunk = hidden_states[i : i + chunk_size]
+ hidden_states_chunk = attn(hidden_states_chunk)
+ hidden_states_chunks.append(hidden_states_chunk)
+ hidden_states = torch.cat(hidden_states_chunks)
+
+ hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2)
+
+ hidden_states = residual + hidden_states
+
+ return hidden_states, new_conv_cache
+
+
+class MochiMidBlock3D(nn.Module):
+ r"""
+ A middle block used in the Mochi model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ num_layers (`int`, defaults to `3`):
+ Number of resnet blocks in the block.
+ """
+
+ def __init__(
+ self,
+ in_channels: int, # 768
+ num_layers: int = 3,
+ add_attention: bool = True,
+ ):
+ super().__init__()
+
+ resnets = []
+ norms = []
+ attentions = []
+
+ for _ in range(num_layers):
+ resnets.append(MochiResnetBlock3D(in_channels=in_channels))
+
+ if add_attention:
+ norms.append(MochiChunkedGroupNorm3D(num_channels=in_channels))
+ attentions.append(
+ Attention(
+ query_dim=in_channels,
+ heads=in_channels // 32,
+ dim_head=32,
+ qk_norm="l2",
+ is_causal=True,
+ processor=MochiVaeAttnProcessor2_0(),
+ )
+ )
+ else:
+ norms.append(None)
+ attentions.append(None)
+
+ self.resnets = nn.ModuleList(resnets)
+ self.norms = nn.ModuleList(norms)
+ self.attentions = nn.ModuleList(attentions)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ r"""Forward method of the `MochiMidBlock3D` class."""
+
+ new_conv_cache = {}
+ conv_cache = conv_cache or {}
+
+ for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
+ conv_cache_key = f"resnet_{i}"
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+ resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+ )
+ else:
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+ )
+
+ if attn is not None:
+ residual = hidden_states
+ hidden_states = norm(hidden_states)
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous()
+ hidden_states = attn(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2)
+
+ hidden_states = residual + hidden_states
+
+ return hidden_states, new_conv_cache
+
+
+class MochiUpBlock3D(nn.Module):
+ r"""
+ An upsampling block used in the Mochi model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet blocks in the block.
+ temporal_expansion (`int`, defaults to `2`):
+ Temporal expansion factor.
+ spatial_expansion (`int`, defaults to `2`):
+ Spatial expansion factor.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ temporal_expansion: int = 2,
+ spatial_expansion: int = 2,
+ ):
+ super().__init__()
+ self.temporal_expansion = temporal_expansion
+ self.spatial_expansion = spatial_expansion
+
+ resnets = []
+ for _ in range(num_layers):
+ resnets.append(MochiResnetBlock3D(in_channels=in_channels))
+ self.resnets = nn.ModuleList(resnets)
+
+ self.proj = nn.Linear(in_channels, out_channels * temporal_expansion * spatial_expansion**2)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ r"""Forward method of the `MochiUpBlock3D` class."""
+
+ new_conv_cache = {}
+ conv_cache = conv_cache or {}
+
+ for i, resnet in enumerate(self.resnets):
+ conv_cache_key = f"resnet_{i}"
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+ resnet,
+ hidden_states,
+ conv_cache=conv_cache.get(conv_cache_key),
+ )
+ else:
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+ )
+
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
+ hidden_states = self.proj(hidden_states)
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ st = self.temporal_expansion
+ sh = self.spatial_expansion
+ sw = self.spatial_expansion
+
+ # Reshape and unpatchify
+ hidden_states = hidden_states.view(batch_size, -1, st, sh, sw, num_frames, height, width)
+ hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
+ hidden_states = hidden_states.view(batch_size, -1, num_frames * st, height * sh, width * sw)
+
+ return hidden_states, new_conv_cache
+
+
+class FourierFeatures(nn.Module):
+ def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
+ super().__init__()
+
+ self.start = start
+ self.stop = stop
+ self.step = step
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ r"""Forward method of the `FourierFeatures` class."""
+ original_dtype = inputs.dtype
+ inputs = inputs.to(torch.float32)
+ num_channels = inputs.shape[1]
+ num_freqs = (self.stop - self.start) // self.step
+
+ freqs = torch.arange(self.start, self.stop, self.step, dtype=inputs.dtype, device=inputs.device)
+ w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
+ w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1]
+
+ # Interleaved repeat of input channels to match w
+ h = inputs.repeat_interleave(
+ num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs
+ ) # [B, C * num_freqs, T, H, W]
+ # Scale channels by frequency.
+ h = w * h
+
+ return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype)
+
+
+class MochiEncoder3D(nn.Module):
+ r"""
+ The `MochiEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
+ representation.
+
+ Args:
+ in_channels (`int`, *optional*):
+ The number of input channels.
+ out_channels (`int`, *optional*):
+ The number of output channels.
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
+ The number of output channels for each block.
+ layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
+ The number of resnet blocks for each block.
+ temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
+ The temporal expansion factor for each of the up blocks.
+ spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
+ The spatial expansion factor for each of the up blocks.
+ non_linearity (`str`, *optional*, defaults to `"swish"`):
+ The non-linearity to use in the decoder.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
+ layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
+ temporal_expansions: Tuple[int, ...] = (1, 2, 3),
+ spatial_expansions: Tuple[int, ...] = (2, 2, 2),
+ add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
+ act_fn: str = "swish",
+ ):
+ super().__init__()
+
+ self.nonlinearity = get_activation(act_fn)
+
+ self.fourier_features = FourierFeatures()
+ self.proj_in = nn.Linear(in_channels, block_out_channels[0])
+ self.block_in = MochiMidBlock3D(
+ in_channels=block_out_channels[0], num_layers=layers_per_block[0], add_attention=add_attention_block[0]
+ )
+
+ down_blocks = []
+ for i in range(len(block_out_channels) - 1):
+ down_block = MochiDownBlock3D(
+ in_channels=block_out_channels[i],
+ out_channels=block_out_channels[i + 1],
+ num_layers=layers_per_block[i + 1],
+ temporal_expansion=temporal_expansions[i],
+ spatial_expansion=spatial_expansions[i],
+ add_attention=add_attention_block[i + 1],
+ )
+ down_blocks.append(down_block)
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ self.block_out = MochiMidBlock3D(
+ in_channels=block_out_channels[-1], num_layers=layers_per_block[-1], add_attention=add_attention_block[-1]
+ )
+ self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1])
+ self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False)
+
+ def forward(
+ self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+ ) -> torch.Tensor:
+ r"""Forward method of the `MochiEncoder3D` class."""
+
+ new_conv_cache = {}
+ conv_cache = conv_cache or {}
+
+ hidden_states = self.fourier_features(hidden_states)
+
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
+ self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
+ )
+
+ for i, down_block in enumerate(self.down_blocks):
+ conv_cache_key = f"down_block_{i}"
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+ down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+ )
+ else:
+ hidden_states, new_conv_cache["block_in"] = self.block_in(
+ hidden_states, conv_cache=conv_cache.get("block_in")
+ )
+
+ for i, down_block in enumerate(self.down_blocks):
+ conv_cache_key = f"down_block_{i}"
+ hidden_states, new_conv_cache[conv_cache_key] = down_block(
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+ )
+
+ hidden_states, new_conv_cache["block_out"] = self.block_out(
+ hidden_states, conv_cache=conv_cache.get("block_out")
+ )
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
+
+ return hidden_states, new_conv_cache
+
+
+class MochiDecoder3D(nn.Module):
+ r"""
+ The `MochiDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
+ sample.
+
+ Args:
+ in_channels (`int`, *optional*):
+ The number of input channels.
+ out_channels (`int`, *optional*):
+ The number of output channels.
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
+ The number of output channels for each block.
+ layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
+ The number of resnet blocks for each block.
+ temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
+ The temporal expansion factor for each of the up blocks.
+ spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
+ The spatial expansion factor for each of the up blocks.
+ non_linearity (`str`, *optional*, defaults to `"swish"`):
+ The non-linearity to use in the decoder.
+ """
+
+ def __init__(
+ self,
+ in_channels: int, # 12
+ out_channels: int, # 3
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
+ layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
+ temporal_expansions: Tuple[int, ...] = (1, 2, 3),
+ spatial_expansions: Tuple[int, ...] = (2, 2, 2),
+ act_fn: str = "swish",
+ ):
+ super().__init__()
+
+ self.nonlinearity = get_activation(act_fn)
+
+ self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1))
+ self.block_in = MochiMidBlock3D(
+ in_channels=block_out_channels[-1],
+ num_layers=layers_per_block[-1],
+ add_attention=False,
+ )
+
+ up_blocks = []
+ for i in range(len(block_out_channels) - 1):
+ up_block = MochiUpBlock3D(
+ in_channels=block_out_channels[-i - 1],
+ out_channels=block_out_channels[-i - 2],
+ num_layers=layers_per_block[-i - 2],
+ temporal_expansion=temporal_expansions[-i - 1],
+ spatial_expansion=spatial_expansions[-i - 1],
+ )
+ up_blocks.append(up_block)
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ self.block_out = MochiMidBlock3D(
+ in_channels=block_out_channels[0],
+ num_layers=layers_per_block[0],
+ add_attention=False,
+ )
+ self.proj_out = nn.Linear(block_out_channels[0], out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
+ ) -> torch.Tensor:
+ r"""Forward method of the `MochiDecoder3D` class."""
+
+ new_conv_cache = {}
+ conv_cache = conv_cache or {}
+
+ hidden_states = self.conv_in(hidden_states)
+
+ # 1. Mid
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
+ self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
+ )
+
+ for i, up_block in enumerate(self.up_blocks):
+ conv_cache_key = f"up_block_{i}"
+ hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
+ up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+ )
+ else:
+ hidden_states, new_conv_cache["block_in"] = self.block_in(
+ hidden_states, conv_cache=conv_cache.get("block_in")
+ )
+
+ for i, up_block in enumerate(self.up_blocks):
+ conv_cache_key = f"up_block_{i}"
+ hidden_states, new_conv_cache[conv_cache_key] = up_block(
+ hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+ )
+
+ hidden_states, new_conv_cache["block_out"] = self.block_out(
+ hidden_states, conv_cache=conv_cache.get("block_out")
+ )
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
+
+ return hidden_states, new_conv_cache
+
+
+class AutoencoderKLMochi(ModelMixin, ConfigMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
+ [Mochi 1 preview](https://github.com/genmoai/models).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["MochiResnetBlock3D"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 15,
+ out_channels: int = 3,
+ encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
+ decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
+ latent_channels: int = 12,
+ layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
+ act_fn: str = "silu",
+ temporal_expansions: Tuple[int, ...] = (1, 2, 3),
+ spatial_expansions: Tuple[int, ...] = (2, 2, 2),
+ add_attention_block: Tuple[bool, ...] = (False, True, True, True, True),
+ latents_mean: Tuple[float, ...] = (
+ -0.06730895953510081,
+ -0.038011381506090416,
+ -0.07477820912866141,
+ -0.05565264470995561,
+ 0.012767231469026969,
+ -0.04703542746246419,
+ 0.043896967884726704,
+ -0.09346305707025976,
+ -0.09918314763016893,
+ -0.008729793427399178,
+ -0.011931556316503654,
+ -0.0321993391887285,
+ ),
+ latents_std: Tuple[float, ...] = (
+ 0.9263795028493863,
+ 0.9248894543193766,
+ 0.9393059390890617,
+ 0.959253732819592,
+ 0.8244560132752793,
+ 0.917259975397747,
+ 0.9294154431013696,
+ 1.3720942357788521,
+ 0.881393668867029,
+ 0.9168315692124348,
+ 0.9185249279345552,
+ 0.9274757570805041,
+ ),
+ scaling_factor: float = 1.0,
+ ):
+ super().__init__()
+
+ self.encoder = MochiEncoder3D(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ block_out_channels=encoder_block_out_channels,
+ layers_per_block=layers_per_block,
+ temporal_expansions=temporal_expansions,
+ spatial_expansions=spatial_expansions,
+ add_attention_block=add_attention_block,
+ act_fn=act_fn,
+ )
+ self.decoder = MochiDecoder3D(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ block_out_channels=decoder_block_out_channels,
+ layers_per_block=layers_per_block,
+ temporal_expansions=temporal_expansions,
+ spatial_expansions=spatial_expansions,
+ act_fn=act_fn,
+ )
+
+ self.spatial_compression_ratio = functools.reduce(lambda x, y: x * y, spatial_expansions, 1)
+ self.temporal_compression_ratio = functools.reduce(lambda x, y: x * y, temporal_expansions, 1)
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+ # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
+ self.use_framewise_encoding = False
+ self.use_framewise_decoding = False
+
+ # This can be used to determine how the number of output frames in the final decoded video. To maintain consistency with
+ # the original implementation, this defaults to `True`.
+ # - Original implementation (drop_last_temporal_frames=True):
+ # Output frames = (latent_frames - 1) * temporal_compression_ratio + 1
+ # - Without dropping additional temporal upscaled frames (drop_last_temporal_frames=False):
+ # Output frames = latent_frames * temporal_compression_ratio
+ # The latter case is useful for frame packing and some training/finetuning scenarios where the additional.
+ self.drop_last_temporal_frames = True
+
+ # This can be configured based on the amount of GPU memory available.
+ # `12` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
+ # Setting it to higher values results in higher memory usage.
+ self.num_sample_frames_batch_size = 12
+ self.num_latent_frames_batch_size = 2
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 256
+ self.tile_sample_min_width = 256
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 192
+ self.tile_sample_stride_width = 192
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _enable_framewise_encoding(self):
+ r"""
+ Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the
+ oneshot encoding implementation without current latent replicate padding.
+
+ Warning: Framewise encoding may not work as expected due to the causal attention layers. If you enable
+ framewise encoding, encode a video, and try to decode it, there will be noticeable jittering effect.
+ """
+ self.use_framewise_encoding = True
+ for name, module in self.named_modules():
+ if isinstance(module, CogVideoXCausalConv3d):
+ module.pad_mode = "constant"
+
+ def _enable_framewise_decoding(self):
+ r"""
+ Enables the framewise VAE decoding implementation with past latent padding. By default, Diffusers uses the
+ oneshot decoding implementation without current latent replicate padding.
+ """
+ self.use_framewise_decoding = True
+ for name, module in self.named_modules():
+ if isinstance(module, CogVideoXCausalConv3d):
+ module.pad_mode = "constant"
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = x.shape
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
+ if self.use_framewise_encoding:
+ raise NotImplementedError(
+ "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. "
+ "As intermediate frames are not independent from each other, they cannot be encoded frame-wise."
+ )
+ else:
+ enc, _ = self.encoder(x)
+
+ return enc
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ if self.use_framewise_decoding:
+ conv_cache = None
+ dec = []
+
+ for i in range(0, num_frames, self.num_latent_frames_batch_size):
+ z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size]
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
+ dec.append(z_intermediate)
+
+ dec = torch.cat(dec, dim=2)
+ else:
+ dec, _ = self.decoder(z)
+
+ if self.drop_last_temporal_frames and dec.size(2) >= self.temporal_compression_ratio:
+ dec = dec[:, :, self.temporal_compression_ratio - 1 :]
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ batch_size, num_channels, num_frames, height, width = x.shape
+ latent_height = height // self.spatial_compression_ratio
+ latent_width = width // self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+ if self.use_framewise_encoding:
+ raise NotImplementedError(
+ "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. "
+ "As intermediate frames are not independent from each other, they cannot be encoded frame-wise."
+ )
+ else:
+ time, _ = self.encoder(
+ x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+ )
+
+ row.append(time)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+ return enc
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+
+ batch_size, num_channels, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ if self.use_framewise_decoding:
+ time = []
+ conv_cache = None
+
+ for k in range(0, num_frames, self.num_latent_frames_batch_size):
+ tile = z[
+ :,
+ :,
+ k : k + self.num_latent_frames_batch_size,
+ i : i + tile_latent_min_height,
+ j : j + tile_latent_min_width,
+ ]
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
+ time.append(tile)
+
+ time = torch.cat(time, dim=2)
+ else:
+ time, _ = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
+
+ if self.drop_last_temporal_frames and time.size(2) >= self.temporal_compression_ratio:
+ time = time[:, :, self.temporal_compression_ratio - 1 :]
+
+ row.append(time)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[torch.Tensor, torch.Tensor]:
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ if not return_dict:
+ return (dec,)
+ return dec
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
index 55449644ed03..5a72cd395196 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..modeling_outputs import AutoencoderKLOutput
@@ -94,49 +94,23 @@ def forward(
sample = self.conv_in(sample)
- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
+ upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ # middle
+ sample = self._gradient_checkpointing_func(
+ self.mid_block,
+ sample,
+ image_only_indicator,
+ )
+ sample = sample.to(upscale_dtype)
- if is_torch_version(">=", "1.11.0"):
- # middle
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block),
- sample,
- image_only_indicator,
- use_reentrant=False,
- )
- sample = sample.to(upscale_dtype)
-
- # up
- for up_block in self.up_blocks:
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(up_block),
- sample,
- image_only_indicator,
- use_reentrant=False,
- )
- else:
- # middle
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block),
+ # up
+ for up_block in self.up_blocks:
+ sample = self._gradient_checkpointing_func(
+ up_block,
sample,
image_only_indicator,
)
- sample = sample.to(upscale_dtype)
-
- # up
- for up_block in self.up_blocks:
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(up_block),
- sample,
- image_only_indicator,
- )
else:
# middle
sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
@@ -228,18 +202,6 @@ def __init__(
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
- sample_size = (
- self.config.sample_size[0]
- if isinstance(self.config.sample_size, (list, tuple))
- else self.config.sample_size
- )
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
- self.tile_overlap_factor = 0.25
-
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, (Encoder, TemporalDecoder)):
- module.gradient_checkpointing = value
-
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
new file mode 100644
index 000000000000..fafb1fe867e3
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -0,0 +1,855 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+CACHE_T = 2
+
+
+class WanCausalConv3d(nn.Conv3d):
+ r"""
+ A custom 3D causal convolution layer with feature caching support.
+
+ This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
+ caching for efficient inference.
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]],
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ padding: Union[int, Tuple[int, int, int]] = 0,
+ ) -> None:
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+
+ # Set up causal padding
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
+ self.padding = (0, 0, 0)
+
+ def forward(self, x, cache_x=None):
+ padding = list(self._padding)
+ if cache_x is not None and self._padding[4] > 0:
+ cache_x = cache_x.to(x.device)
+ x = torch.cat([cache_x, x], dim=2)
+ padding[4] -= cache_x.shape[2]
+ x = F.pad(x, padding)
+ return super().forward(x)
+
+
+class WanRMS_norm(nn.Module):
+ r"""
+ A custom RMS normalization layer.
+
+ Args:
+ dim (int): The number of dimensions to normalize over.
+ channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
+ Default is True.
+ images (bool, optional): Whether the input represents image data. Default is True.
+ bias (bool, optional): Whether to include a learnable bias term. Default is False.
+ """
+
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+ def forward(self, x):
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+
+class WanUpsample(nn.Upsample):
+ r"""
+ Perform upsampling while ensuring the output tensor has the same data type as the input.
+
+ Args:
+ x (torch.Tensor): Input tensor to be upsampled.
+
+ Returns:
+ torch.Tensor: Upsampled tensor with the same data type as the input.
+ """
+
+ def forward(self, x):
+ return super().forward(x.float()).type_as(x)
+
+
+class WanResample(nn.Module):
+ r"""
+ A custom resampling module for 2D and 3D data.
+
+ Args:
+ dim (int): The number of input/output channels.
+ mode (str): The resampling mode. Must be one of:
+ - 'none': No resampling (identity operation).
+ - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
+ - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
+ - 'downsample2d': 2D downsampling with zero-padding and convolution.
+ - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
+ """
+
+ def __init__(self, dim: int, mode: str) -> None:
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ # layers
+ if mode == "upsample2d":
+ self.resample = nn.Sequential(
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+ )
+ elif mode == "upsample3d":
+ self.resample = nn.Sequential(
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+ )
+ self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+ elif mode == "downsample2d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == "downsample3d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+ else:
+ self.resample = nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ b, c, t, h, w = x.size()
+ if self.mode == "upsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = "Rep"
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+ # cache last frame of last two chunk
+ cache_x = torch.cat(
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+ )
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+ if feat_cache[idx] == "Rep":
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ x = self.resample(x)
+ x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+ if self.mode == "downsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone()
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -1:, :, :].clone()
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+
+class WanResidualBlock(nn.Module):
+ r"""
+ A custom residual block module.
+
+ Args:
+ in_dim (int): Number of input channels.
+ out_dim (int): Number of output channels.
+ dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
+ non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ dropout: float = 0.0,
+ non_linearity: str = "silu",
+ ) -> None:
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.nonlinearity = get_activation(non_linearity)
+
+ # layers
+ self.norm1 = WanRMS_norm(in_dim, images=False)
+ self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
+ self.norm2 = WanRMS_norm(out_dim, images=False)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
+ self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ # Apply shortcut connection
+ h = self.conv_shortcut(x)
+
+ # First normalization and activation
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ # Second normalization and activation
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+
+ # Dropout
+ x = self.dropout(x)
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.conv2(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv2(x)
+
+ # Add residual connection
+ return x + h
+
+
+class WanAttentionBlock(nn.Module):
+ r"""
+ Causal self-attention with a single head.
+
+ Args:
+ dim (int): The number of channels in the input tensor.
+ """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ # layers
+ self.norm = WanRMS_norm(dim)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+ self.proj = nn.Conv2d(dim, dim, 1)
+
+ def forward(self, x):
+ identity = x
+ batch_size, channels, time, height, width = x.size()
+
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
+ x = self.norm(x)
+
+ # compute query, key, value
+ qkv = self.to_qkv(x)
+ qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+ qkv = qkv.permute(0, 1, 3, 2).contiguous()
+ q, k, v = qkv.chunk(3, dim=-1)
+
+ # apply attention
+ x = F.scaled_dot_product_attention(q, k, v)
+
+ x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
+
+ # output projection
+ x = self.proj(x)
+
+ # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
+ x = x.view(batch_size, time, channels, height, width)
+ x = x.permute(0, 2, 1, 3, 4)
+
+ return x + identity
+
+
+class WanMidBlock(nn.Module):
+ """
+ Middle block for WanVAE encoder and decoder.
+
+ Args:
+ dim (int): Number of input/output channels.
+ dropout (float): Dropout rate.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
+ super().__init__()
+ self.dim = dim
+
+ # Create the components
+ resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
+ attentions = []
+ for _ in range(num_layers):
+ attentions.append(WanAttentionBlock(dim))
+ resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ # First residual block
+ x = self.resnets[0](x, feat_cache, feat_idx)
+
+ # Process through attention and residual blocks
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ x = attn(x)
+
+ x = resnet(x, feat_cache, feat_idx)
+
+ return x
+
+
+class WanEncoder3d(nn.Module):
+ r"""
+ A 3D encoder module.
+
+ Args:
+ dim (int): The base number of channels in the first layer.
+ z_dim (int): The dimensionality of the latent space.
+ dim_mult (list of int): Multipliers for the number of channels in each block.
+ num_res_blocks (int): Number of residual blocks in each block.
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
+ temperal_downsample (list of bool): Whether to downsample temporally in each block.
+ dropout (float): Dropout rate for the dropout layers.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+ self.nonlinearity = get_activation(non_linearity)
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
+
+ # downsample blocks
+ self.down_blocks = nn.ModuleList([])
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ for _ in range(num_res_blocks):
+ self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ self.down_blocks.append(WanAttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+ self.down_blocks.append(WanResample(out_dim, mode=mode))
+ scale /= 2.0
+
+ # middle blocks
+ self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
+
+ # output blocks
+ self.norm_out = WanRMS_norm(out_dim, images=False)
+ self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_in(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_in(x)
+
+ ## downsamples
+ for layer in self.down_blocks:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## middle
+ x = self.mid_block(x, feat_cache, feat_idx)
+
+ ## head
+ x = self.norm_out(x)
+ x = self.nonlinearity(x)
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_out(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_out(x)
+ return x
+
+
+class WanUpBlock(nn.Module):
+ """
+ A block that handles upsampling for the WanVAE decoder.
+
+ Args:
+ in_dim (int): Input dimension
+ out_dim (int): Output dimension
+ num_res_blocks (int): Number of residual blocks
+ dropout (float): Dropout rate
+ upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
+ non_linearity (str): Type of non-linearity to use
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ num_res_blocks: int,
+ dropout: float = 0.0,
+ upsample_mode: Optional[str] = None,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # Create layers list
+ resnets = []
+ # Add residual blocks and attention if needed
+ current_dim = in_dim
+ for _ in range(num_res_blocks + 1):
+ resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
+ current_dim = out_dim
+
+ self.resnets = nn.ModuleList(resnets)
+
+ # Add upsampling layer if needed
+ self.upsamplers = None
+ if upsample_mode is not None:
+ self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ """
+ Forward pass through the upsampling block.
+
+ Args:
+ x (torch.Tensor): Input tensor
+ feat_cache (list, optional): Feature cache for causal convolutions
+ feat_idx (list, optional): Feature index for cache management
+
+ Returns:
+ torch.Tensor: Output tensor
+ """
+ for resnet in self.resnets:
+ if feat_cache is not None:
+ x = resnet(x, feat_cache, feat_idx)
+ else:
+ x = resnet(x)
+
+ if self.upsamplers is not None:
+ if feat_cache is not None:
+ x = self.upsamplers[0](x, feat_cache, feat_idx)
+ else:
+ x = self.upsamplers[0](x)
+ return x
+
+
+class WanDecoder3d(nn.Module):
+ r"""
+ A 3D decoder module.
+
+ Args:
+ dim (int): The base number of channels in the first layer.
+ z_dim (int): The dimensionality of the latent space.
+ dim_mult (list of int): Multipliers for the number of channels in each block.
+ num_res_blocks (int): Number of residual blocks in each block.
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
+ temperal_upsample (list of bool): Whether to upsample temporally in each block.
+ dropout (float): Dropout rate for the dropout layers.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_upsample=[False, True, True],
+ dropout=0.0,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
+
+ # init block
+ self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
+
+ # middle blocks
+ self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
+
+ # upsample blocks
+ self.up_blocks = nn.ModuleList([])
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ if i > 0:
+ in_dim = in_dim // 2
+
+ # Determine if we need upsampling
+ upsample_mode = None
+ if i != len(dim_mult) - 1:
+ upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
+
+ # Create and add the upsampling block
+ up_block = WanUpBlock(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ num_res_blocks=num_res_blocks,
+ dropout=dropout,
+ upsample_mode=upsample_mode,
+ non_linearity=non_linearity,
+ )
+ self.up_blocks.append(up_block)
+
+ # Update scale for next iteration
+ if upsample_mode is not None:
+ scale *= 2.0
+
+ # output blocks
+ self.norm_out = WanRMS_norm(out_dim, images=False)
+ self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ ## conv1
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_in(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_in(x)
+
+ ## middle
+ x = self.mid_block(x, feat_cache, feat_idx)
+
+ ## upsamples
+ for up_block in self.up_blocks:
+ x = up_block(x, feat_cache, feat_idx)
+
+ ## head
+ x = self.norm_out(x)
+ x = self.nonlinearity(x)
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_out(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_out(x)
+ return x
+
+
+class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
+ Introduced in [Wan 2.1].
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = False
+
+ @register_to_config
+ def __init__(
+ self,
+ base_dim: int = 96,
+ z_dim: int = 16,
+ dim_mult: Tuple[int] = [1, 2, 4, 4],
+ num_res_blocks: int = 2,
+ attn_scales: List[float] = [],
+ temperal_downsample: List[bool] = [False, True, True],
+ dropout: float = 0.0,
+ latents_mean: List[float] = [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921,
+ ],
+ latents_std: List[float] = [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.9160,
+ ],
+ ) -> None:
+ super().__init__()
+
+ self.z_dim = z_dim
+ self.temperal_downsample = temperal_downsample
+ self.temperal_upsample = temperal_downsample[::-1]
+
+ self.encoder = WanEncoder3d(
+ base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+ )
+ self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
+ self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
+
+ self.decoder = WanDecoder3d(
+ base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
+ )
+
+ def clear_cache(self):
+ def _count_conv3d(model):
+ count = 0
+ for m in model.modules():
+ if isinstance(m, WanCausalConv3d):
+ count += 1
+ return count
+
+ self._conv_num = _count_conv3d(self.decoder)
+ self._conv_idx = [0]
+ self._feat_map = [None] * self._conv_num
+ # cache encode
+ self._enc_conv_num = _count_conv3d(self.encoder)
+ self._enc_conv_idx = [0]
+ self._enc_feat_map = [None] * self._enc_conv_num
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ self.clear_cache()
+ ## cache
+ t = x.shape[2]
+ iter_ = 1 + (t - 1) // 4
+ for i in range(iter_):
+ self._enc_conv_idx = [0]
+ if i == 0:
+ out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+ else:
+ out_ = self.encoder(
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx,
+ )
+ out = torch.cat([out, out_], 2)
+
+ enc = self.quant_conv(out)
+ mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
+ enc = torch.cat([mu, logvar], dim=1)
+ self.clear_cache()
+ return enc
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ h = self._encode(x)
+ posterior = DiagonalGaussianDistribution(h)
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ self.clear_cache()
+
+ iter_ = z.shape[2]
+ x = self.post_quant_conv(z)
+ for i in range(iter_):
+ self._conv_idx = [0]
+ if i == 0:
+ out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ else:
+ out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ out = torch.cat([out, out_], 2)
+
+ out = torch.clamp(out, min=-1.0, max=1.0)
+ self.clear_cache()
+ if not return_dict:
+ return (out,)
+
+ return DecoderOutput(sample=out)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ decoded = self._decode(z).sample
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+ return dec
diff --git a/src/diffusers/models/autoencoders/autoencoder_oobleck.py b/src/diffusers/models/autoencoders/autoencoder_oobleck.py
index e8e372a709d7..a8c2a2fd3840 100644
--- a/src/diffusers/models/autoencoders/autoencoder_oobleck.py
+++ b/src/diffusers/models/autoencoders/autoencoder_oobleck.py
@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = False
+ _supports_group_offloading = False
@register_to_config
def __init__(
diff --git a/src/diffusers/models/autoencoders/autoencoder_tiny.py b/src/diffusers/models/autoencoders/autoencoder_tiny.py
index 6e503478fe2b..7ed727c55c37 100644
--- a/src/diffusers/models/autoencoders/autoencoder_tiny.py
+++ b/src/diffusers/models/autoencoders/autoencoder_tiny.py
@@ -154,10 +154,6 @@ def __init__(
self.register_to_config(block_out_channels=decoder_block_out_channels)
self.register_to_config(force_upcast=False)
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
- if isinstance(module, (EncoderTiny, DecoderTiny)):
- module.gradient_checkpointing = value
-
def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
"""raw latents -> [0, 1]"""
return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
@@ -310,7 +306,9 @@ def decode(
self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
if self.use_slicing and x.shape[0] > 1:
- output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
+ output = [
+ self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
+ ]
output = torch.cat(output)
else:
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
@@ -341,7 +339,7 @@ def forward(
# as if we were loading the latents from an RGBA uint8 image.
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
- dec = self.decode(unscaled_enc)
+ dec = self.decode(unscaled_enc).sample
if not return_dict:
return (dec,)
diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
index a97249f79473..a0b3309dc522 100644
--- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py
+++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
@@ -60,7 +60,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
>>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
>>> pipe = StableDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
... ).to("cuda")
>>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
@@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
```
"""
+ _supports_group_offloading = False
+
@register_to_config
def __init__(
self,
diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py
index bb80ce8605ba..72e0acda3afe 100644
--- a/src/diffusers/models/autoencoders/vae.py
+++ b/src/diffusers/models/autoencoders/vae.py
@@ -18,7 +18,7 @@
import torch
import torch.nn as nn
-from ...utils import BaseOutput, is_torch_version
+from ...utils import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..activations import get_activation
from ..attention_processor import SpatialNorm
@@ -30,6 +30,19 @@
)
+@dataclass
+class EncoderOutput(BaseOutput):
+ r"""
+ Output of encoding method.
+
+ Args:
+ latent (`torch.Tensor` of shape `(batch_size, num_channels, latent_height, latent_width)`):
+ The encoded latent.
+ """
+
+ latent: torch.Tensor
+
+
@dataclass
class DecoderOutput(BaseOutput):
r"""
@@ -142,29 +155,12 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
sample = self.conv_in(sample)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
# down
- if is_torch_version(">=", "1.11.0"):
- for down_block in self.down_blocks:
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(down_block), sample, use_reentrant=False
- )
- # middle
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block), sample, use_reentrant=False
- )
- else:
- for down_block in self.down_blocks:
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
- # middle
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+ for down_block in self.down_blocks:
+ sample = self._gradient_checkpointing_func(down_block, sample)
+ # middle
+ sample = self._gradient_checkpointing_func(self.mid_block, sample)
else:
# down
@@ -291,42 +287,14 @@ def forward(
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- # middle
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block),
- sample,
- latent_embeds,
- use_reentrant=False,
- )
- sample = sample.to(upscale_dtype)
-
- # up
- for up_block in self.up_blocks:
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(up_block),
- sample,
- latent_embeds,
- use_reentrant=False,
- )
- else:
- # middle
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block), sample, latent_embeds
- )
- sample = sample.to(upscale_dtype)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ # middle
+ sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
+ sample = sample.to(upscale_dtype)
- # up
- for up_block in self.up_blocks:
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
+ # up
+ for up_block in self.up_blocks:
+ sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
@@ -544,73 +512,29 @@ def forward(
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- # middle
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block),
- sample,
- latent_embeds,
- use_reentrant=False,
- )
- sample = sample.to(upscale_dtype)
-
- # condition encoder
- if image is not None and mask is not None:
- masked_image = (1 - mask) * image
- im_x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.condition_encoder),
- masked_image,
- mask,
- use_reentrant=False,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ # middle
+ sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
+ sample = sample.to(upscale_dtype)
- # up
- for up_block in self.up_blocks:
- if image is not None and mask is not None:
- sample_ = im_x[str(tuple(sample.shape))]
- mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
- sample = sample * mask_ + sample_ * (1 - mask_)
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(up_block),
- sample,
- latent_embeds,
- use_reentrant=False,
- )
- if image is not None and mask is not None:
- sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
- else:
- # middle
- sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block), sample, latent_embeds
+ # condition encoder
+ if image is not None and mask is not None:
+ masked_image = (1 - mask) * image
+ im_x = self._gradient_checkpointing_func(
+ self.condition_encoder,
+ masked_image,
+ mask,
)
- sample = sample.to(upscale_dtype)
- # condition encoder
- if image is not None and mask is not None:
- masked_image = (1 - mask) * image
- im_x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.condition_encoder),
- masked_image,
- mask,
- )
-
- # up
- for up_block in self.up_blocks:
- if image is not None and mask is not None:
- sample_ = im_x[str(tuple(sample.shape))]
- mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
- sample = sample * mask_ + sample_ * (1 - mask_)
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
+ # up
+ for up_block in self.up_blocks:
if image is not None and mask is not None:
- sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
+ sample_ = im_x[str(tuple(sample.shape))]
+ mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
+ sample = sample * mask_ + sample_ * (1 - mask_)
+ sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
+ if image is not None and mask is not None:
+ sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
@@ -876,18 +800,8 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `EncoderTiny` class."""
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
- else:
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ x = self._gradient_checkpointing_func(self.layers, x)
else:
# scale image from [-1, 1] to [0, 1] to match TAESD convention
@@ -962,19 +876,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# Clamp.
x = torch.tanh(x / 3) * 3
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
- else:
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
-
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ x = self._gradient_checkpointing_func(self.layers, x)
else:
x = self.layers(x)
diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py
index ae8a118d719a..84215389bf6a 100644
--- a/src/diffusers/models/autoencoders/vq_model.py
+++ b/src/diffusers/models/autoencoders/vq_model.py
@@ -71,6 +71,9 @@ class VQModel(ModelMixin, ConfigMixin):
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
"""
+ _skip_layerwise_casting_patterns = ["quantize"]
+ _supports_group_offloading = False
+
@register_to_config
def __init__(
self,
diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py
new file mode 100644
index 000000000000..79bd8dc0b254
--- /dev/null
+++ b/src/diffusers/models/cache_utils.py
@@ -0,0 +1,108 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils.logging import get_logger
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+class CacheMixin:
+ r"""
+ A class for enable/disabling caching techniques on diffusion models.
+
+ Supported caching techniques:
+ - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
+ - [FasterCache](https://huggingface.co/papers/2410.19355)
+ """
+
+ _cache_config = None
+
+ @property
+ def is_cache_enabled(self) -> bool:
+ return self._cache_config is not None
+
+ def enable_cache(self, config) -> None:
+ r"""
+ Enable caching techniques on the model.
+
+ Args:
+ config (`Union[PyramidAttentionBroadcastConfig]`):
+ The configuration for applying the caching technique. Currently supported caching techniques are:
+ - [`~hooks.PyramidAttentionBroadcastConfig`]
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
+
+ >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> config = PyramidAttentionBroadcastConfig(
+ ... spatial_attention_block_skip_range=2,
+ ... spatial_attention_timestep_skip_range=(100, 800),
+ ... current_timestep_callback=lambda: pipe.current_timestep,
+ ... )
+ >>> pipe.transformer.enable_cache(config)
+ ```
+ """
+
+ from ..hooks import (
+ FasterCacheConfig,
+ PyramidAttentionBroadcastConfig,
+ apply_faster_cache,
+ apply_pyramid_attention_broadcast,
+ )
+
+ if self.is_cache_enabled:
+ raise ValueError(
+ f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
+ )
+
+ if isinstance(config, PyramidAttentionBroadcastConfig):
+ apply_pyramid_attention_broadcast(self, config)
+ elif isinstance(config, FasterCacheConfig):
+ apply_faster_cache(self, config)
+ else:
+ raise ValueError(f"Cache config {type(config)} is not supported.")
+
+ self._cache_config = config
+
+ def disable_cache(self) -> None:
+ from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
+ from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
+ from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
+
+ if self._cache_config is None:
+ logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
+ return
+
+ if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
+ registry = HookRegistry.check_if_exists_or_initialize(self)
+ registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
+ elif isinstance(self._cache_config, FasterCacheConfig):
+ registry = HookRegistry.check_if_exists_or_initialize(self)
+ registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
+ registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
+ else:
+ raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
+
+ self._cache_config = None
+
+ def _reset_stateful_cache(self, recurse: bool = True) -> None:
+ from ..hooks import HookRegistry
+
+ HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index d3ae96605077..b9ebab818be7 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -11,175 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders.single_file_model import FromOriginalModelMixin
-from ..utils import BaseOutput, logging
-from .attention_processor import (
- ADDED_KV_ATTENTION_PROCESSORS,
- CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
- AttnAddedKVProcessor,
- AttnProcessor,
-)
-from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
-from .modeling_utils import ModelMixin
-from .unets.unet_2d_blocks import (
- CrossAttnDownBlock2D,
- DownBlock2D,
- UNetMidBlock2D,
- UNetMidBlock2DCrossAttn,
- get_down_block,
+from typing import Optional, Tuple, Union
+
+from ..utils import deprecate
+from .controlnets.controlnet import ( # noqa
+ ControlNetConditioningEmbedding,
+ ControlNetModel,
+ ControlNetOutput,
+ zero_module,
)
-from .unets.unet_2d_condition import UNet2DConditionModel
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-@dataclass
-class ControlNetOutput(BaseOutput):
- """
- The output of [`ControlNetModel`].
+class ControlNetOutput(ControlNetOutput):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
+ deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
- Args:
- down_block_res_samples (`tuple[torch.Tensor]`):
- A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
- be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
- used to condition the original UNet's downsampling activations.
- mid_down_block_re_sample (`torch.Tensor`):
- The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
- `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
- Output can be used to condition the original UNet's middle block activation.
- """
-
- down_block_res_samples: Tuple[torch.Tensor]
- mid_block_res_sample: torch.Tensor
-
-
-class ControlNetConditioningEmbedding(nn.Module):
- """
- Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
- [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
- training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
- convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
- (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
- model) to encode image-space conditions ... into feature maps ..."
- """
- def __init__(
- self,
- conditioning_embedding_channels: int,
- conditioning_channels: int = 3,
- block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
- ):
- super().__init__()
-
- self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
-
- self.blocks = nn.ModuleList([])
-
- for i in range(len(block_out_channels) - 1):
- channel_in = block_out_channels[i]
- channel_out = block_out_channels[i + 1]
- self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
- self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
-
- self.conv_out = zero_module(
- nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
- )
-
- def forward(self, conditioning):
- embedding = self.conv_in(conditioning)
- embedding = F.silu(embedding)
-
- for block in self.blocks:
- embedding = block(embedding)
- embedding = F.silu(embedding)
-
- embedding = self.conv_out(embedding)
-
- return embedding
-
-
-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
- """
- A ControlNet model.
-
- Args:
- in_channels (`int`, defaults to 4):
- The number of channels in the input sample.
- flip_sin_to_cos (`bool`, defaults to `True`):
- Whether to flip the sin to cos in the time embedding.
- freq_shift (`int`, defaults to 0):
- The frequency shift to apply to the time embedding.
- down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
- The tuple of downsample blocks to use.
- only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
- block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
- The tuple of output channels for each block.
- layers_per_block (`int`, defaults to 2):
- The number of layers per block.
- downsample_padding (`int`, defaults to 1):
- The padding to use for the downsampling convolution.
- mid_block_scale_factor (`float`, defaults to 1):
- The scale factor to use for the mid block.
- act_fn (`str`, defaults to "silu"):
- The activation function to use.
- norm_num_groups (`int`, *optional*, defaults to 32):
- The number of groups to use for the normalization. If None, normalization and activation layers is skipped
- in post-processing.
- norm_eps (`float`, defaults to 1e-5):
- The epsilon to use for the normalization.
- cross_attention_dim (`int`, defaults to 1280):
- The dimension of the cross attention features.
- transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
- encoder_hid_dim (`int`, *optional*, defaults to None):
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
- dimension to `cross_attention_dim`.
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
- embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
- attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
- The dimension of the attention heads.
- use_linear_projection (`bool`, defaults to `False`):
- class_embed_type (`str`, *optional*, defaults to `None`):
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
- addition_embed_type (`str`, *optional*, defaults to `None`):
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
- "text". "text" will use the `TextTimeEmbedding` layer.
- num_class_embeds (`int`, *optional*, defaults to 0):
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
- class conditioning with `class_embed_type` equal to `None`.
- upcast_attention (`bool`, defaults to `False`):
- resnet_time_scale_shift (`str`, defaults to `"default"`):
- Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
- projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
- The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
- `class_embed_type="projection"`.
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
- conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
- The tuple of output channel for each block in the `conditioning_embedding` layer.
- global_pool_conditions (`bool`, defaults to `False`):
- TODO(Patrick) - unused parameter.
- addition_embed_type_num_heads (`int`, defaults to 64):
- The number of heads to use for the `TextTimeEmbedding` layer.
- """
-
- _supports_gradient_checkpointing = True
-
- @register_to_config
+class ControlNetModel(ControlNetModel):
def __init__(
self,
in_channels: int = 4,
@@ -220,651 +70,46 @@ def __init__(
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
):
- super().__init__()
-
- # If `num_attention_heads` is not defined (which is the case for most models)
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
- # which is why we correct for the naming here.
- num_attention_heads = num_attention_heads or attention_head_dim
-
- # Check inputs
- if len(block_out_channels) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
- )
-
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
-
- # input
- conv_in_kernel = 3
- conv_in_padding = (conv_in_kernel - 1) // 2
- self.conv_in = nn.Conv2d(
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
- )
-
- # time
- time_embed_dim = block_out_channels[0] * 4
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
- timestep_input_dim = block_out_channels[0]
- self.time_embedding = TimestepEmbedding(
- timestep_input_dim,
- time_embed_dim,
- act_fn=act_fn,
- )
-
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
- encoder_hid_dim_type = "text_proj"
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
-
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
- raise ValueError(
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
- )
-
- if encoder_hid_dim_type == "text_proj":
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
- elif encoder_hid_dim_type == "text_image_proj":
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
- # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
- self.encoder_hid_proj = TextImageProjection(
- text_embed_dim=encoder_hid_dim,
- image_embed_dim=cross_attention_dim,
- cross_attention_dim=cross_attention_dim,
- )
-
- elif encoder_hid_dim_type is not None:
- raise ValueError(
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
- )
- else:
- self.encoder_hid_proj = None
-
- # class embedding
- if class_embed_type is None and num_class_embeds is not None:
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
- elif class_embed_type == "timestep":
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
- elif class_embed_type == "identity":
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
- elif class_embed_type == "projection":
- if projection_class_embeddings_input_dim is None:
- raise ValueError(
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
- )
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
- # 2. it projects from an arbitrary input dimension.
- #
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
- else:
- self.class_embedding = None
-
- if addition_embed_type == "text":
- if encoder_hid_dim is not None:
- text_time_embedding_from_dim = encoder_hid_dim
- else:
- text_time_embedding_from_dim = cross_attention_dim
-
- self.add_embedding = TextTimeEmbedding(
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
- )
- elif addition_embed_type == "text_image":
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
- # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
- self.add_embedding = TextImageTimeEmbedding(
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
- )
- elif addition_embed_type == "text_time":
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
-
- elif addition_embed_type is not None:
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
-
- # control net conditioning embedding
- self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
- conditioning_embedding_channels=block_out_channels[0],
- block_out_channels=conditioning_embedding_out_channels,
+ deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
+ deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message)
+ super().__init__(
+ in_channels=in_channels,
conditioning_channels=conditioning_channels,
- )
-
- self.down_blocks = nn.ModuleList([])
- self.controlnet_down_blocks = nn.ModuleList([])
-
- if isinstance(only_cross_attention, bool):
- only_cross_attention = [only_cross_attention] * len(down_block_types)
-
- if isinstance(attention_head_dim, int):
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
-
- if isinstance(num_attention_heads, int):
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
-
- # down
- output_channel = block_out_channels[0]
-
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- for i, down_block_type in enumerate(down_block_types):
- input_channel = output_channel
- output_channel = block_out_channels[i]
- is_final_block = i == len(block_out_channels) - 1
-
- down_block = get_down_block(
- down_block_type,
- num_layers=layers_per_block,
- transformer_layers_per_block=transformer_layers_per_block[i],
- in_channels=input_channel,
- out_channels=output_channel,
- temb_channels=time_embed_dim,
- add_downsample=not is_final_block,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads[i],
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
- downsample_padding=downsample_padding,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention[i],
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- )
- self.down_blocks.append(down_block)
-
- for _ in range(layers_per_block):
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- if not is_final_block:
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- # mid
- mid_block_channel = block_out_channels[-1]
-
- controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_mid_block = controlnet_block
-
- if mid_block_type == "UNetMidBlock2DCrossAttn":
- self.mid_block = UNetMidBlock2DCrossAttn(
- transformer_layers_per_block=transformer_layers_per_block[-1],
- in_channels=mid_block_channel,
- temb_channels=time_embed_dim,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_time_scale_shift=resnet_time_scale_shift,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads[-1],
- resnet_groups=norm_num_groups,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- )
- elif mid_block_type == "UNetMidBlock2D":
- self.mid_block = UNetMidBlock2D(
- in_channels=block_out_channels[-1],
- temb_channels=time_embed_dim,
- num_layers=0,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_groups=norm_num_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- add_attention=False,
- )
- else:
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
-
- @classmethod
- def from_unet(
- cls,
- unet: UNet2DConditionModel,
- controlnet_conditioning_channel_order: str = "rgb",
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
- load_weights_from_unet: bool = True,
- conditioning_channels: int = 3,
- ):
- r"""
- Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
-
- Parameters:
- unet (`UNet2DConditionModel`):
- The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
- where applicable.
- """
- transformer_layers_per_block = (
- unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
- )
- encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
- encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
- addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
- addition_time_embed_dim = (
- unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
- )
-
- controlnet = cls(
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ down_block_types=down_block_types,
+ mid_block_type=mid_block_type,
+ only_cross_attention=only_cross_attention,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ downsample_padding=downsample_padding,
+ mid_block_scale_factor=mid_block_scale_factor,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ cross_attention_dim=cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ use_linear_projection=use_linear_projection,
+ class_embed_type=class_embed_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
- transformer_layers_per_block=transformer_layers_per_block,
- in_channels=unet.config.in_channels,
- flip_sin_to_cos=unet.config.flip_sin_to_cos,
- freq_shift=unet.config.freq_shift,
- down_block_types=unet.config.down_block_types,
- only_cross_attention=unet.config.only_cross_attention,
- block_out_channels=unet.config.block_out_channels,
- layers_per_block=unet.config.layers_per_block,
- downsample_padding=unet.config.downsample_padding,
- mid_block_scale_factor=unet.config.mid_block_scale_factor,
- act_fn=unet.config.act_fn,
- norm_num_groups=unet.config.norm_num_groups,
- norm_eps=unet.config.norm_eps,
- cross_attention_dim=unet.config.cross_attention_dim,
- attention_head_dim=unet.config.attention_head_dim,
- num_attention_heads=unet.config.num_attention_heads,
- use_linear_projection=unet.config.use_linear_projection,
- class_embed_type=unet.config.class_embed_type,
- num_class_embeds=unet.config.num_class_embeds,
- upcast_attention=unet.config.upcast_attention,
- resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
- projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
- mid_block_type=unet.config.mid_block_type,
+ num_class_embeds=num_class_embeds,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
- conditioning_channels=conditioning_channels,
- )
-
- if load_weights_from_unet:
- controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
- controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
- controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
-
- if controlnet.class_embedding:
- controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
-
- if hasattr(controlnet, "add_embedding"):
- controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
-
- controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
- controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
-
- return controlnet
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
- def set_default_attn_processor(self):
- """
- Disables custom attention processors and sets the default attention implementation.
- """
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
- processor = AttnAddedKVProcessor()
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
- processor = AttnProcessor()
- else:
- raise ValueError(
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
- )
-
- self.set_attn_processor(processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
- def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
- r"""
- Enable sliced attention computation.
-
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
-
- Args:
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
- must be a multiple of `slice_size`.
- """
- sliceable_head_dims = []
-
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
- if hasattr(module, "set_attention_slice"):
- sliceable_head_dims.append(module.sliceable_head_dim)
-
- for child in module.children():
- fn_recursive_retrieve_sliceable_dims(child)
-
- # retrieve number of attention layers
- for module in self.children():
- fn_recursive_retrieve_sliceable_dims(module)
-
- num_sliceable_layers = len(sliceable_head_dims)
-
- if slice_size == "auto":
- # half the attention head size is usually a good trade-off between
- # speed and memory
- slice_size = [dim // 2 for dim in sliceable_head_dims]
- elif slice_size == "max":
- # make smallest slice possible
- slice_size = num_sliceable_layers * [1]
-
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
-
- if len(slice_size) != len(sliceable_head_dims):
- raise ValueError(
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
- )
-
- for i in range(len(slice_size)):
- size = slice_size[i]
- dim = sliceable_head_dims[i]
- if size is not None and size > dim:
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
-
- # Recursively walk through all the children.
- # Any children which exposes the set_attention_slice method
- # gets the message
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
- if hasattr(module, "set_attention_slice"):
- module.set_attention_slice(slice_size.pop())
-
- for child in module.children():
- fn_recursive_set_attention_slice(child, slice_size)
-
- reversed_slice_size = list(reversed(slice_size))
- for module in self.children():
- fn_recursive_set_attention_slice(module, reversed_slice_size)
-
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
- module.gradient_checkpointing = value
-
- def forward(
- self,
- sample: torch.Tensor,
- timestep: Union[torch.Tensor, float, int],
- encoder_hidden_states: torch.Tensor,
- controlnet_cond: torch.Tensor,
- conditioning_scale: float = 1.0,
- class_labels: Optional[torch.Tensor] = None,
- timestep_cond: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- guess_mode: bool = False,
- return_dict: bool = True,
- ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
- """
- The [`ControlNetModel`] forward method.
-
- Args:
- sample (`torch.Tensor`):
- The noisy input tensor.
- timestep (`Union[torch.Tensor, float, int]`):
- The number of timesteps to denoise an input.
- encoder_hidden_states (`torch.Tensor`):
- The encoder hidden states.
- controlnet_cond (`torch.Tensor`):
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
- conditioning_scale (`float`, defaults to `1.0`):
- The scale factor for ControlNet outputs.
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
- Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
- timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
- embeddings.
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
- negative values to the attention scores corresponding to "discard" tokens.
- added_cond_kwargs (`dict`):
- Additional conditions for the Stable Diffusion XL UNet.
- cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
- guess_mode (`bool`, defaults to `False`):
- In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
- you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
- return_dict (`bool`, defaults to `True`):
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
-
- Returns:
- [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
- If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
- returned where the first element is the sample tensor.
- """
- # check channel order
- channel_order = self.config.controlnet_conditioning_channel_order
-
- if channel_order == "rgb":
- # in rgb order by default
- ...
- elif channel_order == "bgr":
- controlnet_cond = torch.flip(controlnet_cond, dims=[1])
- else:
- raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
-
- # prepare attention_mask
- if attention_mask is not None:
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
- attention_mask = attention_mask.unsqueeze(1)
-
- # 1. time
- timesteps = timestep
- if not torch.is_tensor(timesteps):
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
- # This would be a good case for the `match` statement (Python 3.10+)
- is_mps = sample.device.type == "mps"
- if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
- else:
- dtype = torch.int32 if is_mps else torch.int64
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
- elif len(timesteps.shape) == 0:
- timesteps = timesteps[None].to(sample.device)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timesteps = timesteps.expand(sample.shape[0])
-
- t_emb = self.time_proj(timesteps)
-
- # timesteps does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=sample.dtype)
-
- emb = self.time_embedding(t_emb, timestep_cond)
- aug_emb = None
-
- if self.class_embedding is not None:
- if class_labels is None:
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
-
- if self.config.class_embed_type == "timestep":
- class_labels = self.time_proj(class_labels)
-
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
- emb = emb + class_emb
-
- if self.config.addition_embed_type is not None:
- if self.config.addition_embed_type == "text":
- aug_emb = self.add_embedding(encoder_hidden_states)
-
- elif self.config.addition_embed_type == "text_time":
- if "text_embeds" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
- )
- text_embeds = added_cond_kwargs.get("text_embeds")
- if "time_ids" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
- )
- time_ids = added_cond_kwargs.get("time_ids")
- time_embeds = self.add_time_proj(time_ids.flatten())
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
-
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
- add_embeds = add_embeds.to(emb.dtype)
- aug_emb = self.add_embedding(add_embeds)
-
- emb = emb + aug_emb if aug_emb is not None else emb
-
- # 2. pre-process
- sample = self.conv_in(sample)
-
- controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
- sample = sample + controlnet_cond
-
- # 3. down
- down_block_res_samples = (sample,)
- for downsample_block in self.down_blocks:
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
- sample, res_samples = downsample_block(
- hidden_states=sample,
- temb=emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- )
- else:
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
-
- down_block_res_samples += res_samples
-
- # 4. mid
- if self.mid_block is not None:
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
- sample = self.mid_block(
- sample,
- emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- )
- else:
- sample = self.mid_block(sample, emb)
-
- # 5. Control net blocks
- controlnet_down_block_res_samples = ()
-
- for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
- down_block_res_sample = controlnet_block(down_block_res_sample)
- controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
-
- down_block_res_samples = controlnet_down_block_res_samples
-
- mid_block_res_sample = self.controlnet_mid_block(sample)
-
- # 6. scaling
- if guess_mode and not self.config.global_pool_conditions:
- scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
- scales = scales * conditioning_scale
- down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
- mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
- else:
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
- mid_block_res_sample = mid_block_res_sample * conditioning_scale
-
- if self.config.global_pool_conditions:
- down_block_res_samples = [
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
- ]
- mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
-
- if not return_dict:
- return (down_block_res_samples, mid_block_res_sample)
-
- return ControlNetOutput(
- down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ global_pool_conditions=global_pool_conditions,
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
)
-def zero_module(module):
- for p in module.parameters():
- nn.init.zeros_(p)
- return module
+class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
+ deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py
index 961e30155a3d..2035deb1062d 100644
--- a/src/diffusers/models/controlnet_flux.py
+++ b/src/diffusers/models/controlnet_flux.py
@@ -12,36 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-import torch
-import torch.nn as nn
+from typing import List
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import PeftAdapterMixin
-from ..models.attention_processor import AttentionProcessor
-from ..models.modeling_utils import ModelMixin
-from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
-from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
-from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
-from .modeling_outputs import Transformer2DModelOutput
-from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+from ..utils import deprecate, logging
+from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-@dataclass
-class FluxControlNetOutput(BaseOutput):
- controlnet_block_samples: Tuple[torch.Tensor]
- controlnet_single_block_samples: Tuple[torch.Tensor]
+class FluxControlNetOutput(FluxControlNetOutput):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
+ deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
-class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
- _supports_gradient_checkpointing = True
-
- @register_to_config
+class FluxControlNetModel(FluxControlNetModel):
def __init__(
self,
patch_size: int = 1,
@@ -57,480 +45,26 @@ def __init__(
num_mode: int = None,
conditioning_embedding_channels: int = None,
):
- super().__init__()
- self.out_channels = in_channels
- self.inner_dim = num_attention_heads * attention_head_dim
-
- self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
- text_time_guidance_cls = (
- CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
- )
- self.time_text_embed = text_time_guidance_cls(
- embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
- )
-
- self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
- self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
-
- self.transformer_blocks = nn.ModuleList(
- [
- FluxTransformerBlock(
- dim=self.inner_dim,
- num_attention_heads=num_attention_heads,
- attention_head_dim=attention_head_dim,
- )
- for i in range(num_layers)
- ]
- )
-
- self.single_transformer_blocks = nn.ModuleList(
- [
- FluxSingleTransformerBlock(
- dim=self.inner_dim,
- num_attention_heads=num_attention_heads,
- attention_head_dim=attention_head_dim,
- )
- for i in range(num_single_layers)
- ]
- )
-
- # controlnet_blocks
- self.controlnet_blocks = nn.ModuleList([])
- for _ in range(len(self.transformer_blocks)):
- self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
-
- self.controlnet_single_blocks = nn.ModuleList([])
- for _ in range(len(self.single_transformer_blocks)):
- self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
-
- self.union = num_mode is not None
- if self.union:
- self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
-
- if conditioning_embedding_channels is not None:
- self.input_hint_block = ControlNetConditioningEmbedding(
- conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
- )
- self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
- else:
- self.input_hint_block = None
- self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
-
- self.gradient_checkpointing = False
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self):
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
- @classmethod
- def from_transformer(
- cls,
- transformer,
- num_layers: int = 4,
- num_single_layers: int = 10,
- attention_head_dim: int = 128,
- num_attention_heads: int = 24,
- load_weights_from_transformer=True,
- ):
- config = transformer.config
- config["num_layers"] = num_layers
- config["num_single_layers"] = num_single_layers
- config["attention_head_dim"] = attention_head_dim
- config["num_attention_heads"] = num_attention_heads
-
- controlnet = cls(**config)
-
- if load_weights_from_transformer:
- controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
- controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
- controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
- controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
- controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
- controlnet.single_transformer_blocks.load_state_dict(
- transformer.single_transformer_blocks.state_dict(), strict=False
- )
-
- controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
-
- return controlnet
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- controlnet_cond: torch.Tensor,
- controlnet_mode: torch.Tensor = None,
- conditioning_scale: float = 1.0,
- encoder_hidden_states: torch.Tensor = None,
- pooled_projections: torch.Tensor = None,
- timestep: torch.LongTensor = None,
- img_ids: torch.Tensor = None,
- txt_ids: torch.Tensor = None,
- guidance: torch.Tensor = None,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- return_dict: bool = True,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
- """
- The [`FluxTransformer2DModel`] forward method.
-
- Args:
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
- Input `hidden_states`.
- controlnet_cond (`torch.Tensor`):
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
- controlnet_mode (`torch.Tensor`):
- The mode tensor of shape `(batch_size, 1)`.
- conditioning_scale (`float`, defaults to `1.0`):
- The scale factor for ControlNet outputs.
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
- Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
- from the embeddings of input conditions.
- timestep ( `torch.LongTensor`):
- Used to indicate denoising step.
- block_controlnet_hidden_states: (`list` of `torch.Tensor`):
- A list of tensors that if specified are added to the residuals of transformer blocks.
- joint_attention_kwargs (`dict`, *optional*):
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
- `self.processor` in
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
- tuple.
-
- Returns:
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
- `tuple` where the first element is the sample tensor.
- """
- if joint_attention_kwargs is not None:
- joint_attention_kwargs = joint_attention_kwargs.copy()
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
- else:
- lora_scale = 1.0
-
- if USE_PEFT_BACKEND:
- # weight the lora layers by setting `lora_scale` for each PEFT layer
- scale_lora_layers(self, lora_scale)
- else:
- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
- logger.warning(
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
- )
- hidden_states = self.x_embedder(hidden_states)
-
- if self.input_hint_block is not None:
- controlnet_cond = self.input_hint_block(controlnet_cond)
- batch_size, channels, height_pw, width_pw = controlnet_cond.shape
- height = height_pw // self.config.patch_size
- width = width_pw // self.config.patch_size
- controlnet_cond = controlnet_cond.reshape(
- batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
- )
- controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
- controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
- # add
- hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
-
- timestep = timestep.to(hidden_states.dtype) * 1000
- if guidance is not None:
- guidance = guidance.to(hidden_states.dtype) * 1000
- else:
- guidance = None
- temb = (
- self.time_text_embed(timestep, pooled_projections)
- if guidance is None
- else self.time_text_embed(timestep, guidance, pooled_projections)
+ deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
+ deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message)
+ super().__init__(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ joint_attention_dim=joint_attention_dim,
+ pooled_projection_dim=pooled_projection_dim,
+ guidance_embeds=guidance_embeds,
+ axes_dims_rope=axes_dims_rope,
+ num_mode=num_mode,
+ conditioning_embedding_channels=conditioning_embedding_channels,
)
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-
- if self.union:
- # union mode
- if controlnet_mode is None:
- raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
- # union mode emb
- controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
- encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
- txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
-
- if txt_ids.ndim == 3:
- logger.warning(
- "Passing `txt_ids` 3d torch.Tensor is deprecated."
- "Please remove the batch dimension and pass it as a 2d torch Tensor"
- )
- txt_ids = txt_ids[0]
- if img_ids.ndim == 3:
- logger.warning(
- "Passing `img_ids` 3d torch.Tensor is deprecated."
- "Please remove the batch dimension and pass it as a 2d torch Tensor"
- )
- img_ids = img_ids[0]
-
- ids = torch.cat((txt_ids, img_ids), dim=0)
- image_rotary_emb = self.pos_embed(ids)
-
- block_samples = ()
- for index_block, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
- hidden_states,
- encoder_hidden_states,
- temb,
- image_rotary_emb,
- **ckpt_kwargs,
- )
-
- else:
- encoder_hidden_states, hidden_states = block(
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- temb=temb,
- image_rotary_emb=image_rotary_emb,
- )
- block_samples = block_samples + (hidden_states,)
-
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
- single_block_samples = ()
- for index_block, block in enumerate(self.single_transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
- hidden_states,
- temb,
- image_rotary_emb,
- **ckpt_kwargs,
- )
-
- else:
- hidden_states = block(
- hidden_states=hidden_states,
- temb=temb,
- image_rotary_emb=image_rotary_emb,
- )
- single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
-
- # controlnet block
- controlnet_block_samples = ()
- for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
- block_sample = controlnet_block(block_sample)
- controlnet_block_samples = controlnet_block_samples + (block_sample,)
-
- controlnet_single_block_samples = ()
- for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
- single_block_sample = controlnet_block(single_block_sample)
- controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
-
- # scaling
- controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
- controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
-
- controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
- controlnet_single_block_samples = (
- None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
- )
-
- if USE_PEFT_BACKEND:
- # remove `lora_scale` from each PEFT layer
- unscale_lora_layers(self, lora_scale)
-
- if not return_dict:
- return (controlnet_block_samples, controlnet_single_block_samples)
-
- return FluxControlNetOutput(
- controlnet_block_samples=controlnet_block_samples,
- controlnet_single_block_samples=controlnet_single_block_samples,
- )
-
-
-class FluxMultiControlNetModel(ModelMixin):
- r"""
- `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
-
- This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
- compatible with `FluxControlNetModel`.
-
- Args:
- controlnets (`List[FluxControlNetModel]`):
- Provides additional conditioning to the unet during the denoising process. You must set multiple
- `FluxControlNetModel` as a list.
- """
-
- def __init__(self, controlnets):
- super().__init__()
- self.nets = nn.ModuleList(controlnets)
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- controlnet_cond: List[torch.tensor],
- controlnet_mode: List[torch.tensor],
- conditioning_scale: List[float],
- encoder_hidden_states: torch.Tensor = None,
- pooled_projections: torch.Tensor = None,
- timestep: torch.LongTensor = None,
- img_ids: torch.Tensor = None,
- txt_ids: torch.Tensor = None,
- guidance: torch.Tensor = None,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- return_dict: bool = True,
- ) -> Union[FluxControlNetOutput, Tuple]:
- # ControlNet-Union with multiple conditions
- # only load one ControlNet for saving memories
- if len(self.nets) == 1 and self.nets[0].union:
- controlnet = self.nets[0]
-
- for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
- block_samples, single_block_samples = controlnet(
- hidden_states=hidden_states,
- controlnet_cond=image,
- controlnet_mode=mode[:, None],
- conditioning_scale=scale,
- timestep=timestep,
- guidance=guidance,
- pooled_projections=pooled_projections,
- encoder_hidden_states=encoder_hidden_states,
- txt_ids=txt_ids,
- img_ids=img_ids,
- joint_attention_kwargs=joint_attention_kwargs,
- return_dict=return_dict,
- )
-
- # merge samples
- if i == 0:
- control_block_samples = block_samples
- control_single_block_samples = single_block_samples
- else:
- control_block_samples = [
- control_block_sample + block_sample
- for control_block_sample, block_sample in zip(control_block_samples, block_samples)
- ]
-
- control_single_block_samples = [
- control_single_block_sample + block_sample
- for control_single_block_sample, block_sample in zip(
- control_single_block_samples, single_block_samples
- )
- ]
-
- # Regular Multi-ControlNets
- # load all ControlNets into memories
- else:
- for i, (image, mode, scale, controlnet) in enumerate(
- zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
- ):
- block_samples, single_block_samples = controlnet(
- hidden_states=hidden_states,
- controlnet_cond=image,
- controlnet_mode=mode[:, None],
- conditioning_scale=scale,
- timestep=timestep,
- guidance=guidance,
- pooled_projections=pooled_projections,
- encoder_hidden_states=encoder_hidden_states,
- txt_ids=txt_ids,
- img_ids=img_ids,
- joint_attention_kwargs=joint_attention_kwargs,
- return_dict=return_dict,
- )
- # merge samples
- if i == 0:
- control_block_samples = block_samples
- control_single_block_samples = single_block_samples
- else:
- if block_samples is not None and control_block_samples is not None:
- control_block_samples = [
- control_block_sample + block_sample
- for control_block_sample, block_sample in zip(control_block_samples, block_samples)
- ]
- if single_block_samples is not None and control_single_block_samples is not None:
- control_single_block_samples = [
- control_single_block_sample + block_sample
- for control_single_block_sample, block_sample in zip(
- control_single_block_samples, single_block_samples
- )
- ]
- return control_block_samples, control_single_block_samples
+class FluxMultiControlNetModel(FluxMultiControlNetModel):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
+ deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py
index 43b52a645a0d..0f7246c6c6d4 100644
--- a/src/diffusers/models/controlnet_sd3.py
+++ b/src/diffusers/models/controlnet_sd3.py
@@ -13,35 +13,21 @@
# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ..models.attention import JointTransformerBlock
-from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
-from ..models.modeling_outputs import Transformer2DModelOutput
-from ..models.modeling_utils import ModelMixin
-from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
-from .controlnet import BaseOutput, zero_module
-from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
+from ..utils import deprecate, logging
+from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-@dataclass
-class SD3ControlNetOutput(BaseOutput):
- controlnet_block_samples: Tuple[torch.Tensor]
-
+class SD3ControlNetOutput(SD3ControlNetOutput):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
+ deprecate("diffusers.models.controlnet_sd3.SD3ControlNetOutput", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
-class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
- _supports_gradient_checkpointing = True
- @register_to_config
+class SD3ControlNetModel(SD3ControlNetModel):
def __init__(
self,
sample_size: int = 128,
@@ -57,366 +43,26 @@ def __init__(
pos_embed_max_size: int = 96,
extra_conditioning_channels: int = 0,
):
- super().__init__()
- default_out_channels = in_channels
- self.out_channels = out_channels if out_channels is not None else default_out_channels
- self.inner_dim = num_attention_heads * attention_head_dim
-
- self.pos_embed = PatchEmbed(
- height=sample_size,
- width=sample_size,
+ deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
+ deprecate("diffusers.models.controlnet_sd3.SD3ControlNetModel", "0.34", deprecation_message)
+ super().__init__(
+ sample_size=sample_size,
patch_size=patch_size,
in_channels=in_channels,
- embed_dim=self.inner_dim,
+ num_layers=num_layers,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ joint_attention_dim=joint_attention_dim,
+ caption_projection_dim=caption_projection_dim,
+ pooled_projection_dim=pooled_projection_dim,
+ out_channels=out_channels,
pos_embed_max_size=pos_embed_max_size,
+ extra_conditioning_channels=extra_conditioning_channels,
)
- self.time_text_embed = CombinedTimestepTextProjEmbeddings(
- embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
- )
- self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
-
- # `attention_head_dim` is doubled to account for the mixing.
- # It needs to crafted when we get the actual checkpoints.
- self.transformer_blocks = nn.ModuleList(
- [
- JointTransformerBlock(
- dim=self.inner_dim,
- num_attention_heads=num_attention_heads,
- attention_head_dim=self.config.attention_head_dim,
- context_pre_only=False,
- )
- for i in range(num_layers)
- ]
- )
-
- # controlnet_blocks
- self.controlnet_blocks = nn.ModuleList([])
- for _ in range(len(self.transformer_blocks)):
- controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_blocks.append(controlnet_block)
- pos_embed_input = PatchEmbed(
- height=sample_size,
- width=sample_size,
- patch_size=patch_size,
- in_channels=in_channels + extra_conditioning_channels,
- embed_dim=self.inner_dim,
- pos_embed_type=None,
- )
- self.pos_embed_input = zero_module(pos_embed_input)
-
- self.gradient_checkpointing = False
-
- # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
- def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
- """
- Sets the attention processor to use [feed forward
- chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
-
- Parameters:
- chunk_size (`int`, *optional*):
- The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
- over each tensor of dim=`dim`.
- dim (`int`, *optional*, defaults to `0`):
- The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
- or dim=1 (sequence length).
- """
- if dim not in [0, 1]:
- raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
-
- # By default chunk size is 1
- chunk_size = chunk_size or 1
-
- def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
- if hasattr(module, "set_chunk_feed_forward"):
- module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
-
- for child in module.children():
- fn_recursive_feed_forward(child, chunk_size, dim)
-
- for module in self.children():
- fn_recursive_feed_forward(module, chunk_size, dim)
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedJointAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
- @classmethod
- def from_transformer(
- cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
- ):
- config = transformer.config
- config["num_layers"] = num_layers or config.num_layers
- config["extra_conditioning_channels"] = num_extra_conditioning_channels
- controlnet = cls(**config)
-
- if load_weights_from_transformer:
- controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
- controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
- controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
- controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
-
- controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
-
- return controlnet
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- controlnet_cond: torch.Tensor,
- conditioning_scale: float = 1.0,
- encoder_hidden_states: torch.FloatTensor = None,
- pooled_projections: torch.FloatTensor = None,
- timestep: torch.LongTensor = None,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- return_dict: bool = True,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
- """
- The [`SD3Transformer2DModel`] forward method.
-
- Args:
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
- Input `hidden_states`.
- controlnet_cond (`torch.Tensor`):
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
- conditioning_scale (`float`, defaults to `1.0`):
- The scale factor for ControlNet outputs.
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
- Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
- from the embeddings of input conditions.
- timestep ( `torch.LongTensor`):
- Used to indicate denoising step.
- joint_attention_kwargs (`dict`, *optional*):
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
- `self.processor` in
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
- tuple.
-
- Returns:
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
- `tuple` where the first element is the sample tensor.
- """
- if joint_attention_kwargs is not None:
- joint_attention_kwargs = joint_attention_kwargs.copy()
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
- else:
- lora_scale = 1.0
-
- if USE_PEFT_BACKEND:
- # weight the lora layers by setting `lora_scale` for each PEFT layer
- scale_lora_layers(self, lora_scale)
- else:
- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
- logger.warning(
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
- )
-
- hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
- temb = self.time_text_embed(timestep, pooled_projections)
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-
- # add
- hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
-
- block_res_samples = ()
-
- for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
- hidden_states,
- encoder_hidden_states,
- temb,
- **ckpt_kwargs,
- )
-
- else:
- encoder_hidden_states, hidden_states = block(
- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
- )
-
- block_res_samples = block_res_samples + (hidden_states,)
-
- controlnet_block_res_samples = ()
- for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
- block_res_sample = controlnet_block(block_res_sample)
- controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
-
- # 6. scaling
- controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
-
- if USE_PEFT_BACKEND:
- # remove `lora_scale` from each PEFT layer
- unscale_lora_layers(self, lora_scale)
-
- if not return_dict:
- return (controlnet_block_res_samples,)
-
- return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
-
-
-class SD3MultiControlNetModel(ModelMixin):
- r"""
- `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
-
- This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
- compatible with `SD3ControlNetModel`.
-
- Args:
- controlnets (`List[SD3ControlNetModel]`):
- Provides additional conditioning to the unet during the denoising process. You must set multiple
- `SD3ControlNetModel` as a list.
- """
-
- def __init__(self, controlnets):
- super().__init__()
- self.nets = nn.ModuleList(controlnets)
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- controlnet_cond: List[torch.tensor],
- conditioning_scale: List[float],
- pooled_projections: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- timestep: torch.LongTensor = None,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- return_dict: bool = True,
- ) -> Union[SD3ControlNetOutput, Tuple]:
- for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
- block_samples = controlnet(
- hidden_states=hidden_states,
- timestep=timestep,
- encoder_hidden_states=encoder_hidden_states,
- pooled_projections=pooled_projections,
- controlnet_cond=image,
- conditioning_scale=scale,
- joint_attention_kwargs=joint_attention_kwargs,
- return_dict=return_dict,
- )
- # merge samples
- if i == 0:
- control_block_samples = block_samples
- else:
- control_block_samples = [
- control_block_sample + block_sample
- for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
- ]
- control_block_samples = (tuple(control_block_samples),)
- return control_block_samples
+class SD3MultiControlNetModel(SD3MultiControlNetModel):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
+ deprecate("diffusers.models.controlnet_sd3.SD3MultiControlNetModel", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py
index fa37e1f9e393..8fdaa21bef11 100644
--- a/src/diffusers/models/controlnet_sparsectrl.py
+++ b/src/diffusers/models/controlnet_sparsectrl.py
@@ -12,152 +12,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-import torch
-from torch import nn
-from torch.nn import functional as F
+from typing import Optional, Tuple, Union
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalModelMixin
-from ..utils import BaseOutput, logging
-from .attention_processor import (
- ADDED_KV_ATTENTION_PROCESSORS,
- CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
- AttnAddedKVProcessor,
- AttnProcessor,
+from ..utils import deprecate, logging
+from .controlnets.controlnet_sparsectrl import ( # noqa
+ SparseControlNetConditioningEmbedding,
+ SparseControlNetModel,
+ SparseControlNetOutput,
+ zero_module,
)
-from .embeddings import TimestepEmbedding, Timesteps
-from .modeling_utils import ModelMixin
-from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
-from .unets.unet_2d_condition import UNet2DConditionModel
-from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-@dataclass
-class SparseControlNetOutput(BaseOutput):
- """
- The output of [`SparseControlNetModel`].
+class SparseControlNetOutput(SparseControlNetOutput):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
+ deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetOutput", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
- Args:
- down_block_res_samples (`tuple[torch.Tensor]`):
- A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
- be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
- used to condition the original UNet's downsampling activations.
- mid_down_block_re_sample (`torch.Tensor`):
- The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
- `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
- Output can be used to condition the original UNet's middle block activation.
- """
- down_block_res_samples: Tuple[torch.Tensor]
- mid_block_res_sample: torch.Tensor
-
-
-class SparseControlNetConditioningEmbedding(nn.Module):
- def __init__(
- self,
- conditioning_embedding_channels: int,
- conditioning_channels: int = 3,
- block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
- ):
- super().__init__()
-
- self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
- self.blocks = nn.ModuleList([])
-
- for i in range(len(block_out_channels) - 1):
- channel_in = block_out_channels[i]
- channel_out = block_out_channels[i + 1]
- self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
- self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
-
- self.conv_out = zero_module(
- nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
+ deprecate(
+ "diffusers.models.controlnet_sparsectrl.SparseControlNetConditioningEmbedding", "0.34", deprecation_message
)
+ super().__init__(*args, **kwargs)
- def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
- embedding = self.conv_in(conditioning)
- embedding = F.silu(embedding)
- for block in self.blocks:
- embedding = block(embedding)
- embedding = F.silu(embedding)
-
- embedding = self.conv_out(embedding)
- return embedding
-
-
-class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
- """
- A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
- Models](https://arxiv.org/abs/2311.16933).
-
- Args:
- in_channels (`int`, defaults to 4):
- The number of channels in the input sample.
- conditioning_channels (`int`, defaults to 4):
- The number of input channels in the controlnet conditional embedding module. If
- `concat_condition_embedding` is True, the value provided here is incremented by 1.
- flip_sin_to_cos (`bool`, defaults to `True`):
- Whether to flip the sin to cos in the time embedding.
- freq_shift (`int`, defaults to 0):
- The frequency shift to apply to the time embedding.
- down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
- The tuple of downsample blocks to use.
- only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
- block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
- The tuple of output channels for each block.
- layers_per_block (`int`, defaults to 2):
- The number of layers per block.
- downsample_padding (`int`, defaults to 1):
- The padding to use for the downsampling convolution.
- mid_block_scale_factor (`float`, defaults to 1):
- The scale factor to use for the mid block.
- act_fn (`str`, defaults to "silu"):
- The activation function to use.
- norm_num_groups (`int`, *optional*, defaults to 32):
- The number of groups to use for the normalization. If None, normalization and activation layers is skipped
- in post-processing.
- norm_eps (`float`, defaults to 1e-5):
- The epsilon to use for the normalization.
- cross_attention_dim (`int`, defaults to 1280):
- The dimension of the cross attention features.
- transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
- transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
- The number of transformer layers to use in each layer in the middle block.
- attention_head_dim (`int` or `Tuple[int]`, defaults to 8):
- The dimension of the attention heads.
- num_attention_heads (`int` or `Tuple[int]`, *optional*):
- The number of heads to use for multi-head attention.
- use_linear_projection (`bool`, defaults to `False`):
- upcast_attention (`bool`, defaults to `False`):
- resnet_time_scale_shift (`str`, defaults to `"default"`):
- Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
- conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
- The tuple of output channel for each block in the `conditioning_embedding` layer.
- global_pool_conditions (`bool`, defaults to `False`):
- TODO(Patrick) - unused parameter
- controlnet_conditioning_channel_order (`str`, defaults to `rgb`):
- motion_max_seq_length (`int`, defaults to `32`):
- The maximum sequence length to use in the motion module.
- motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`):
- The number of heads to use in each attention layer of the motion module.
- concat_conditioning_mask (`bool`, defaults to `True`):
- use_simplified_condition_embedding (`bool`, defaults to `True`):
- """
-
- _supports_gradient_checkpointing = True
-
- @register_to_config
+class SparseControlNetModel(SparseControlNetModel):
def __init__(
self,
in_channels: int = 4,
@@ -195,594 +81,36 @@ def __init__(
concat_conditioning_mask: bool = True,
use_simplified_condition_embedding: bool = True,
):
- super().__init__()
- self.use_simplified_condition_embedding = use_simplified_condition_embedding
-
- # If `num_attention_heads` is not defined (which is the case for most models)
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
- # which is why we correct for the naming here.
- num_attention_heads = num_attention_heads or attention_head_dim
-
- # Check inputs
- if len(block_out_channels) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
- )
-
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
- raise ValueError(
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
- )
-
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
- if isinstance(temporal_transformer_layers_per_block, int):
- temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types)
-
- # input
- conv_in_kernel = 3
- conv_in_padding = (conv_in_kernel - 1) // 2
- self.conv_in = nn.Conv2d(
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
- )
-
- if concat_conditioning_mask:
- conditioning_channels = conditioning_channels + 1
-
- self.concat_conditioning_mask = concat_conditioning_mask
-
- # control net conditioning embedding
- if use_simplified_condition_embedding:
- self.controlnet_cond_embedding = zero_module(
- nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
- )
- else:
- self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding(
- conditioning_embedding_channels=block_out_channels[0],
- block_out_channels=conditioning_embedding_out_channels,
- conditioning_channels=conditioning_channels,
- )
-
- # time
- time_embed_dim = block_out_channels[0] * 4
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
- timestep_input_dim = block_out_channels[0]
-
- self.time_embedding = TimestepEmbedding(
- timestep_input_dim,
- time_embed_dim,
+ deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
+ deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetModel", "0.34", deprecation_message)
+ super().__init__(
+ in_channels=in_channels,
+ conditioning_channels=conditioning_channels,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ down_block_types=down_block_types,
+ only_cross_attention=only_cross_attention,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ downsample_padding=downsample_padding,
+ mid_block_scale_factor=mid_block_scale_factor,
act_fn=act_fn,
- )
-
- self.down_blocks = nn.ModuleList([])
- self.controlnet_down_blocks = nn.ModuleList([])
-
- if isinstance(cross_attention_dim, int):
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
-
- if isinstance(only_cross_attention, bool):
- only_cross_attention = [only_cross_attention] * len(down_block_types)
-
- if isinstance(attention_head_dim, int):
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
-
- if isinstance(num_attention_heads, int):
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
-
- if isinstance(motion_num_attention_heads, int):
- motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types)
-
- # down
- output_channel = block_out_channels[0]
-
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- for i, down_block_type in enumerate(down_block_types):
- input_channel = output_channel
- output_channel = block_out_channels[i]
- is_final_block = i == len(block_out_channels) - 1
-
- if down_block_type == "CrossAttnDownBlockMotion":
- down_block = CrossAttnDownBlockMotion(
- in_channels=input_channel,
- out_channels=output_channel,
- temb_channels=time_embed_dim,
- dropout=0,
- num_layers=layers_per_block,
- transformer_layers_per_block=transformer_layers_per_block[i],
- resnet_eps=norm_eps,
- resnet_time_scale_shift=resnet_time_scale_shift,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- resnet_pre_norm=True,
- num_attention_heads=num_attention_heads[i],
- cross_attention_dim=cross_attention_dim[i],
- add_downsample=not is_final_block,
- dual_cross_attention=False,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention[i],
- upcast_attention=upcast_attention,
- temporal_num_attention_heads=motion_num_attention_heads[i],
- temporal_max_seq_length=motion_max_seq_length,
- temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
- temporal_double_self_attention=False,
- )
- elif down_block_type == "DownBlockMotion":
- down_block = DownBlockMotion(
- in_channels=input_channel,
- out_channels=output_channel,
- temb_channels=time_embed_dim,
- dropout=0,
- num_layers=layers_per_block,
- resnet_eps=norm_eps,
- resnet_time_scale_shift=resnet_time_scale_shift,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- resnet_pre_norm=True,
- add_downsample=not is_final_block,
- temporal_num_attention_heads=motion_num_attention_heads[i],
- temporal_max_seq_length=motion_max_seq_length,
- temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
- temporal_double_self_attention=False,
- )
- else:
- raise ValueError(
- "Invalid `block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`"
- )
-
- self.down_blocks.append(down_block)
-
- for _ in range(layers_per_block):
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- if not is_final_block:
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_down_blocks.append(controlnet_block)
-
- # mid
- mid_block_channels = block_out_channels[-1]
-
- controlnet_block = nn.Conv2d(mid_block_channels, mid_block_channels, kernel_size=1)
- controlnet_block = zero_module(controlnet_block)
- self.controlnet_mid_block = controlnet_block
-
- if transformer_layers_per_mid_block is None:
- transformer_layers_per_mid_block = (
- transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1
- )
-
- self.mid_block = UNetMidBlock2DCrossAttn(
- in_channels=mid_block_channels,
- temb_channels=time_embed_dim,
- dropout=0,
- num_layers=1,
- transformer_layers_per_block=transformer_layers_per_mid_block,
- resnet_eps=norm_eps,
- resnet_time_scale_shift=resnet_time_scale_shift,
- resnet_act_fn=act_fn,
- resnet_groups=norm_num_groups,
- resnet_pre_norm=True,
- num_attention_heads=num_attention_heads[-1],
- output_scale_factor=mid_block_scale_factor,
- cross_attention_dim=cross_attention_dim[-1],
- dual_cross_attention=False,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ cross_attention_dim=cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ transformer_layers_per_mid_block=transformer_layers_per_mid_block,
+ temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
- attention_type="default",
- )
-
- @classmethod
- def from_unet(
- cls,
- unet: UNet2DConditionModel,
- controlnet_conditioning_channel_order: str = "rgb",
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
- load_weights_from_unet: bool = True,
- conditioning_channels: int = 3,
- ) -> "SparseControlNetModel":
- r"""
- Instantiate a [`SparseControlNetModel`] from [`UNet2DConditionModel`].
-
- Parameters:
- unet (`UNet2DConditionModel`):
- The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also
- copied where applicable.
- """
- transformer_layers_per_block = (
- unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
- )
- down_block_types = unet.config.down_block_types
-
- for i in range(len(down_block_types)):
- if "CrossAttn" in down_block_types[i]:
- down_block_types[i] = "CrossAttnDownBlockMotion"
- elif "Down" in down_block_types[i]:
- down_block_types[i] = "DownBlockMotion"
- else:
- raise ValueError("Invalid `block_type` encountered. Must be a cross-attention or down block")
-
- controlnet = cls(
- in_channels=unet.config.in_channels,
- conditioning_channels=conditioning_channels,
- flip_sin_to_cos=unet.config.flip_sin_to_cos,
- freq_shift=unet.config.freq_shift,
- down_block_types=unet.config.down_block_types,
- only_cross_attention=unet.config.only_cross_attention,
- block_out_channels=unet.config.block_out_channels,
- layers_per_block=unet.config.layers_per_block,
- downsample_padding=unet.config.downsample_padding,
- mid_block_scale_factor=unet.config.mid_block_scale_factor,
- act_fn=unet.config.act_fn,
- norm_num_groups=unet.config.norm_num_groups,
- norm_eps=unet.config.norm_eps,
- cross_attention_dim=unet.config.cross_attention_dim,
- transformer_layers_per_block=transformer_layers_per_block,
- attention_head_dim=unet.config.attention_head_dim,
- num_attention_heads=unet.config.num_attention_heads,
- use_linear_projection=unet.config.use_linear_projection,
- upcast_attention=unet.config.upcast_attention,
- resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ resnet_time_scale_shift=resnet_time_scale_shift,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ global_pool_conditions=global_pool_conditions,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+ motion_max_seq_length=motion_max_seq_length,
+ motion_num_attention_heads=motion_num_attention_heads,
+ concat_conditioning_mask=concat_conditioning_mask,
+ use_simplified_condition_embedding=use_simplified_condition_embedding,
)
-
- if load_weights_from_unet:
- controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False)
- controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False)
- controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False)
- controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
- controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
-
- return controlnet
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
- def set_default_attn_processor(self):
- """
- Disables custom attention processors and sets the default attention implementation.
- """
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
- processor = AttnAddedKVProcessor()
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
- processor = AttnProcessor()
- else:
- raise ValueError(
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
- )
-
- self.set_attn_processor(processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
- def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
- r"""
- Enable sliced attention computation.
-
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
-
- Args:
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
- must be a multiple of `slice_size`.
- """
- sliceable_head_dims = []
-
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
- if hasattr(module, "set_attention_slice"):
- sliceable_head_dims.append(module.sliceable_head_dim)
-
- for child in module.children():
- fn_recursive_retrieve_sliceable_dims(child)
-
- # retrieve number of attention layers
- for module in self.children():
- fn_recursive_retrieve_sliceable_dims(module)
-
- num_sliceable_layers = len(sliceable_head_dims)
-
- if slice_size == "auto":
- # half the attention head size is usually a good trade-off between
- # speed and memory
- slice_size = [dim // 2 for dim in sliceable_head_dims]
- elif slice_size == "max":
- # make smallest slice possible
- slice_size = num_sliceable_layers * [1]
-
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
-
- if len(slice_size) != len(sliceable_head_dims):
- raise ValueError(
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
- )
-
- for i in range(len(slice_size)):
- size = slice_size[i]
- dim = sliceable_head_dims[i]
- if size is not None and size > dim:
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
-
- # Recursively walk through all the children.
- # Any children which exposes the set_attention_slice method
- # gets the message
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
- if hasattr(module, "set_attention_slice"):
- module.set_attention_slice(slice_size.pop())
-
- for child in module.children():
- fn_recursive_set_attention_slice(child, slice_size)
-
- reversed_slice_size = list(reversed(slice_size))
- for module in self.children():
- fn_recursive_set_attention_slice(module, reversed_slice_size)
-
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
- if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)):
- module.gradient_checkpointing = value
-
- def forward(
- self,
- sample: torch.Tensor,
- timestep: Union[torch.Tensor, float, int],
- encoder_hidden_states: torch.Tensor,
- controlnet_cond: torch.Tensor,
- conditioning_scale: float = 1.0,
- timestep_cond: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- conditioning_mask: Optional[torch.Tensor] = None,
- guess_mode: bool = False,
- return_dict: bool = True,
- ) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
- """
- The [`SparseControlNetModel`] forward method.
-
- Args:
- sample (`torch.Tensor`):
- The noisy input tensor.
- timestep (`Union[torch.Tensor, float, int]`):
- The number of timesteps to denoise an input.
- encoder_hidden_states (`torch.Tensor`):
- The encoder hidden states.
- controlnet_cond (`torch.Tensor`):
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
- conditioning_scale (`float`, defaults to `1.0`):
- The scale factor for ControlNet outputs.
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
- Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
- timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
- embeddings.
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
- negative values to the attention scores corresponding to "discard" tokens.
- added_cond_kwargs (`dict`):
- Additional conditions for the Stable Diffusion XL UNet.
- cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
- guess_mode (`bool`, defaults to `False`):
- In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
- you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
- return_dict (`bool`, defaults to `True`):
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
- Returns:
- [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
- If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
- returned where the first element is the sample tensor.
- """
- sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape
- sample = torch.zeros_like(sample)
-
- # check channel order
- channel_order = self.config.controlnet_conditioning_channel_order
-
- if channel_order == "rgb":
- # in rgb order by default
- ...
- elif channel_order == "bgr":
- controlnet_cond = torch.flip(controlnet_cond, dims=[1])
- else:
- raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
-
- # prepare attention_mask
- if attention_mask is not None:
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
- attention_mask = attention_mask.unsqueeze(1)
-
- # 1. time
- timesteps = timestep
- if not torch.is_tensor(timesteps):
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
- # This would be a good case for the `match` statement (Python 3.10+)
- is_mps = sample.device.type == "mps"
- if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
- else:
- dtype = torch.int32 if is_mps else torch.int64
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
- elif len(timesteps.shape) == 0:
- timesteps = timesteps[None].to(sample.device)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timesteps = timesteps.expand(sample.shape[0])
-
- t_emb = self.time_proj(timesteps)
-
- # timesteps does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=sample.dtype)
-
- emb = self.time_embedding(t_emb, timestep_cond)
- emb = emb.repeat_interleave(sample_num_frames, dim=0)
-
- # 2. pre-process
- batch_size, channels, num_frames, height, width = sample.shape
-
- sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
- sample = self.conv_in(sample)
-
- batch_frames, channels, height, width = sample.shape
- sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width)
-
- if self.concat_conditioning_mask:
- controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
-
- batch_size, channels, num_frames, height, width = controlnet_cond.shape
- controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape(
- batch_size * num_frames, channels, height, width
- )
- controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
- batch_frames, channels, height, width = controlnet_cond.shape
- controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width)
-
- sample = sample + controlnet_cond
-
- batch_size, num_frames, channels, height, width = sample.shape
- sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width)
-
- # 3. down
- down_block_res_samples = (sample,)
- for downsample_block in self.down_blocks:
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
- sample, res_samples = downsample_block(
- hidden_states=sample,
- temb=emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- num_frames=num_frames,
- cross_attention_kwargs=cross_attention_kwargs,
- )
- else:
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
-
- down_block_res_samples += res_samples
-
- # 4. mid
- if self.mid_block is not None:
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
- sample = self.mid_block(
- sample,
- emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- )
- else:
- sample = self.mid_block(sample, emb)
-
- # 5. Control net blocks
- controlnet_down_block_res_samples = ()
-
- for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
- down_block_res_sample = controlnet_block(down_block_res_sample)
- controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
-
- down_block_res_samples = controlnet_down_block_res_samples
- mid_block_res_sample = self.controlnet_mid_block(sample)
-
- # 6. scaling
- if guess_mode and not self.config.global_pool_conditions:
- scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
- scales = scales * conditioning_scale
- down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
- mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
- else:
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
- mid_block_res_sample = mid_block_res_sample * conditioning_scale
-
- if self.config.global_pool_conditions:
- down_block_res_samples = [
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
- ]
- mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
-
- if not return_dict:
- return (down_block_res_samples, mid_block_res_sample)
-
- return SparseControlNetOutput(
- down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
- )
-
-
-# Copied from diffusers.models.controlnet.zero_module
-def zero_module(module: nn.Module) -> nn.Module:
- for p in module.parameters():
- nn.init.zeros_(p)
- return module
diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py
new file mode 100644
index 000000000000..1dd92e51a44c
--- /dev/null
+++ b/src/diffusers/models/controlnets/__init__.py
@@ -0,0 +1,24 @@
+from ...utils import is_flax_available, is_torch_available
+
+
+if is_torch_available():
+ from .controlnet import ControlNetModel, ControlNetOutput
+ from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
+ from .controlnet_hunyuan import (
+ HunyuanControlNetOutput,
+ HunyuanDiT2DControlNetModel,
+ HunyuanDiT2DMultiControlNetModel,
+ )
+ from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
+ from .controlnet_sparsectrl import (
+ SparseControlNetConditioningEmbedding,
+ SparseControlNetModel,
+ SparseControlNetOutput,
+ )
+ from .controlnet_union import ControlNetUnionModel
+ from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
+ from .multicontrolnet import MultiControlNetModel
+ from .multicontrolnet_union import MultiControlNetUnionModel
+
+if is_flax_available():
+ from .controlnet_flax import FlaxControlNetModel
diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py
new file mode 100644
index 000000000000..7a6ca886caed
--- /dev/null
+++ b/src/diffusers/models/controlnets/controlnet.py
@@ -0,0 +1,867 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import BaseOutput, logging
+from ..attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from ..modeling_utils import ModelMixin
+from ..unets.unet_2d_blocks import (
+ UNetMidBlock2D,
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+)
+from ..unets.unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetOutput(BaseOutput):
+ """
+ The output of [`ControlNetModel`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_down_block_re_sample (`torch.Tensor`):
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
+ """
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
+
+
+class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ """
+ A ControlNet model.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ use_linear_projection (`bool`, defaults to `False`):
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ num_class_embeds (`int`, *optional*, defaults to 0):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ TODO(Patrick) - unused parameter.
+ addition_embed_type_num_heads (`int`, defaults to 64):
+ The number of heads to use for the `TextTimeEmbedding` layer.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 3,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super().__init__()
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ # control net conditioning embedding
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ elif mid_block_type == "UNetMidBlock2D":
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ num_layers=0,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=False,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ conditioning_channels: int = 3,
+ ):
+ r"""
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
+ where applicable.
+ """
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+
+ controlnet = cls(
+ encoder_hid_dim=encoder_hid_dim,
+ encoder_hid_dim_type=encoder_hid_dim_type,
+ addition_embed_type=addition_embed_type,
+ addition_time_embed_dim=addition_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=unet.config.in_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ down_block_types=unet.config.down_block_types,
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ attention_head_dim=unet.config.attention_head_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ use_linear_projection=unet.config.use_linear_projection,
+ class_embed_type=unet.config.class_embed_type,
+ num_class_embeds=unet.config.num_class_embeds,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ mid_block_type=unet.config.mid_block_type,
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ if controlnet.class_embedding:
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+ if hasattr(controlnet, "add_embedding"):
+ controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
+
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
+
+ return controlnet
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
+ """
+ The [`ControlNetModel`] forward method.
+
+ Args:
+ sample (`torch.Tensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ guess_mode (`bool`, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned,
+ otherwise a tuple is returned where the first element is the sample tensor.
+ """
+ # check channel order
+ channel_order = self.config.controlnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ elif channel_order == "bgr":
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
+ else:
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ else:
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type is not None:
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+ sample = sample + controlnet_cond
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # 5. Control net blocks
+
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+ scales = scales * conditioning_scale
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
+ else:
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if self.config.global_pool_conditions:
+ down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return ControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnets/controlnet_flax.py
similarity index 98%
rename from src/diffusers/models/controlnet_flax.py
rename to src/diffusers/models/controlnets/controlnet_flax.py
index 0540850a9e61..ab8d9b5f8cbb 100644
--- a/src/diffusers/models/controlnet_flax.py
+++ b/src/diffusers/models/controlnets/controlnet_flax.py
@@ -19,11 +19,11 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
-from ..configuration_utils import ConfigMixin, flax_register_to_config
-from ..utils import BaseOutput
-from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
-from .modeling_flax_utils import FlaxModelMixin
-from .unets.unet_2d_blocks_flax import (
+from ...configuration_utils import ConfigMixin, flax_register_to_config
+from ...utils import BaseOutput
+from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
+from ..modeling_flax_utils import FlaxModelMixin
+from ..unets.unet_2d_blocks_flax import (
FlaxCrossAttnDownBlock2D,
FlaxDownBlock2D,
FlaxUNetMidBlock2DCrossAttn,
diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py
new file mode 100644
index 000000000000..51c34b7fe965
--- /dev/null
+++ b/src/diffusers/models/controlnets/controlnet_flux.py
@@ -0,0 +1,508 @@
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...models.attention_processor import AttentionProcessor
+from ...models.modeling_utils import ModelMixin
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
+from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class FluxControlNetOutput(BaseOutput):
+ controlnet_block_samples: Tuple[torch.Tensor]
+ controlnet_single_block_samples: Tuple[torch.Tensor]
+
+
+class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ pooled_projection_dim: int = 768,
+ guidance_embeds: bool = False,
+ axes_dims_rope: List[int] = [16, 56, 56],
+ num_mode: int = None,
+ conditioning_embedding_channels: int = None,
+ ):
+ super().__init__()
+ self.out_channels = in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+ text_time_guidance_cls = (
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+ )
+ self.time_text_embed = text_time_guidance_cls(
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+ )
+
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+ self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ FluxTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ FluxSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for i in range(num_single_layers)
+ ]
+ )
+
+ # controlnet_blocks
+ self.controlnet_blocks = nn.ModuleList([])
+ for _ in range(len(self.transformer_blocks)):
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+
+ self.controlnet_single_blocks = nn.ModuleList([])
+ for _ in range(len(self.single_transformer_blocks)):
+ self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+
+ self.union = num_mode is not None
+ if self.union:
+ self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
+
+ if conditioning_embedding_channels is not None:
+ self.input_hint_block = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
+ )
+ self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
+ else:
+ self.input_hint_block = None
+ self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
+
+ self.gradient_checkpointing = False
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self):
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ @classmethod
+ def from_transformer(
+ cls,
+ transformer,
+ num_layers: int = 4,
+ num_single_layers: int = 10,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ load_weights_from_transformer=True,
+ ):
+ config = dict(transformer.config)
+ config["num_layers"] = num_layers
+ config["num_single_layers"] = num_single_layers
+ config["attention_head_dim"] = attention_head_dim
+ config["num_attention_heads"] = num_attention_heads
+
+ controlnet = cls.from_config(config)
+
+ if load_weights_from_transformer:
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
+ controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
+ controlnet.single_transformer_blocks.load_state_dict(
+ transformer.single_transformer_blocks.state_dict(), strict=False
+ )
+
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
+
+ return controlnet
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ controlnet_mode: torch.Tensor = None,
+ conditioning_scale: float = 1.0,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ controlnet_mode (`torch.Tensor`):
+ The mode tensor of shape `(batch_size, 1)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ hidden_states = self.x_embedder(hidden_states)
+
+ if self.input_hint_block is not None:
+ controlnet_cond = self.input_hint_block(controlnet_cond)
+ batch_size, channels, height_pw, width_pw = controlnet_cond.shape
+ height = height_pw // self.config.patch_size
+ width = width_pw // self.config.patch_size
+ controlnet_cond = controlnet_cond.reshape(
+ batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
+ )
+ controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
+ controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
+ # add
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+ else:
+ guidance = None
+ temb = (
+ self.time_text_embed(timestep, pooled_projections)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, pooled_projections)
+ )
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if self.union:
+ # union mode
+ if controlnet_mode is None:
+ raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
+ # union mode emb
+ controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
+ encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
+ txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ block_samples = ()
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ block_samples = block_samples + (hidden_states,)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ single_block_samples = ()
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ temb,
+ image_rotary_emb,
+ )
+
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+
+ # controlnet block
+ controlnet_block_samples = ()
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
+ block_sample = controlnet_block(block_sample)
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
+
+ controlnet_single_block_samples = ()
+ for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
+ single_block_sample = controlnet_block(single_block_sample)
+ controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
+
+ # scaling
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
+ controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
+
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
+ controlnet_single_block_samples = (
+ None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
+ )
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (controlnet_block_samples, controlnet_single_block_samples)
+
+ return FluxControlNetOutput(
+ controlnet_block_samples=controlnet_block_samples,
+ controlnet_single_block_samples=controlnet_single_block_samples,
+ )
+
+
+class FluxMultiControlNetModel(ModelMixin):
+ r"""
+ `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
+
+ This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
+ compatible with `FluxControlNetModel`.
+
+ Args:
+ controlnets (`List[FluxControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `FluxControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ controlnet_cond: List[torch.tensor],
+ controlnet_mode: List[torch.tensor],
+ conditioning_scale: List[float],
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[FluxControlNetOutput, Tuple]:
+ # ControlNet-Union with multiple conditions
+ # only load one ControlNet for saving memories
+ if len(self.nets) == 1 and self.nets[0].union:
+ controlnet = self.nets[0]
+
+ for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
+ block_samples, single_block_samples = controlnet(
+ hidden_states=hidden_states,
+ controlnet_cond=image,
+ controlnet_mode=mode[:, None],
+ conditioning_scale=scale,
+ timestep=timestep,
+ guidance=guidance,
+ pooled_projections=pooled_projections,
+ encoder_hidden_states=encoder_hidden_states,
+ txt_ids=txt_ids,
+ img_ids=img_ids,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ control_single_block_samples = single_block_samples
+ else:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+ ]
+
+ control_single_block_samples = [
+ control_single_block_sample + block_sample
+ for control_single_block_sample, block_sample in zip(
+ control_single_block_samples, single_block_samples
+ )
+ ]
+
+ # Regular Multi-ControlNets
+ # load all ControlNets into memories
+ else:
+ for i, (image, mode, scale, controlnet) in enumerate(
+ zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
+ ):
+ block_samples, single_block_samples = controlnet(
+ hidden_states=hidden_states,
+ controlnet_cond=image,
+ controlnet_mode=mode[:, None],
+ conditioning_scale=scale,
+ timestep=timestep,
+ guidance=guidance,
+ pooled_projections=pooled_projections,
+ encoder_hidden_states=encoder_hidden_states,
+ txt_ids=txt_ids,
+ img_ids=img_ids,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ control_single_block_samples = single_block_samples
+ else:
+ if block_samples is not None and control_block_samples is not None:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+ ]
+ if single_block_samples is not None and control_single_block_samples is not None:
+ control_single_block_samples = [
+ control_single_block_sample + block_sample
+ for control_single_block_sample, block_sample in zip(
+ control_single_block_samples, single_block_samples
+ )
+ ]
+
+ return control_block_samples, control_single_block_samples
diff --git a/src/diffusers/models/controlnet_hunyuan.py b/src/diffusers/models/controlnets/controlnet_hunyuan.py
similarity index 98%
rename from src/diffusers/models/controlnet_hunyuan.py
rename to src/diffusers/models/controlnets/controlnet_hunyuan.py
index 4277d81d1cb9..fade44def4cd 100644
--- a/src/diffusers/models/controlnet_hunyuan.py
+++ b/src/diffusers/models/controlnets/controlnet_hunyuan.py
@@ -17,17 +17,17 @@
import torch
from torch import nn
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import logging
-from .attention_processor import AttentionProcessor
-from .controlnet import BaseOutput, Tuple, zero_module
-from .embeddings import (
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput, logging
+from ..attention_processor import AttentionProcessor
+from ..embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
PatchEmbed,
PixArtAlphaTextProjection,
)
-from .modeling_utils import ModelMixin
-from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock
+from ..modeling_utils import ModelMixin
+from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock
+from .controlnet import Tuple, zero_module
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
new file mode 100644
index 000000000000..91ce76fe75a9
--- /dev/null
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -0,0 +1,513 @@
+# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import JointTransformerBlock
+from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
+from .controlnet import BaseOutput, zero_module
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class SD3ControlNetOutput(BaseOutput):
+ controlnet_block_samples: Tuple[torch.Tensor]
+
+
+class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ r"""
+ ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
+
+ Parameters:
+ sample_size (`int`, defaults to `128`):
+ The width/height of the latents. This is fixed during training since it is used to learn a number of
+ position embeddings.
+ patch_size (`int`, defaults to `2`):
+ Patch size to turn the input data into small patches.
+ in_channels (`int`, defaults to `16`):
+ The number of latent channels in the input.
+ num_layers (`int`, defaults to `18`):
+ The number of layers of transformer blocks to use.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ num_attention_heads (`int`, defaults to `18`):
+ The number of heads to use for multi-head attention.
+ joint_attention_dim (`int`, defaults to `4096`):
+ The embedding dimension to use for joint text-image attention.
+ caption_projection_dim (`int`, defaults to `1152`):
+ The embedding dimension of caption embeddings.
+ pooled_projection_dim (`int`, defaults to `2048`):
+ The embedding dimension of pooled text projections.
+ out_channels (`int`, defaults to `16`):
+ The number of latent channels in the output.
+ pos_embed_max_size (`int`, defaults to `96`):
+ The maximum latent height/width of positional embeddings.
+ extra_conditioning_channels (`int`, defaults to `0`):
+ The number of extra channels to use for conditioning for patch embedding.
+ dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+ The number of dual-stream transformer blocks to use.
+ qk_norm (`str`, *optional*, defaults to `None`):
+ The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
+ pos_embed_type (`str`, defaults to `"sincos"`):
+ The type of positional embedding to use. Choose between `"sincos"` and `None`.
+ use_pos_embed (`bool`, defaults to `True`):
+ Whether to use positional embeddings.
+ force_zeros_for_pooled_projection (`bool`, defaults to `True`):
+ Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
+ config value of the ControlNet model.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 128,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ num_layers: int = 18,
+ attention_head_dim: int = 64,
+ num_attention_heads: int = 18,
+ joint_attention_dim: int = 4096,
+ caption_projection_dim: int = 1152,
+ pooled_projection_dim: int = 2048,
+ out_channels: int = 16,
+ pos_embed_max_size: int = 96,
+ extra_conditioning_channels: int = 0,
+ dual_attention_layers: Tuple[int, ...] = (),
+ qk_norm: Optional[str] = None,
+ pos_embed_type: Optional[str] = "sincos",
+ use_pos_embed: bool = True,
+ force_zeros_for_pooled_projection: bool = True,
+ ):
+ super().__init__()
+ default_out_channels = in_channels
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ if use_pos_embed:
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=self.inner_dim,
+ pos_embed_max_size=pos_embed_max_size,
+ pos_embed_type=pos_embed_type,
+ )
+ else:
+ self.pos_embed = None
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+ )
+ if joint_attention_dim is not None:
+ self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+ # `attention_head_dim` is doubled to account for the mixing.
+ # It needs to crafted when we get the actual checkpoints.
+ self.transformer_blocks = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ context_pre_only=False,
+ qk_norm=qk_norm,
+ use_dual_attention=True if i in dual_attention_layers else False,
+ )
+ for i in range(num_layers)
+ ]
+ )
+ else:
+ self.context_embedder = None
+ self.transformer_blocks = nn.ModuleList(
+ [
+ SD3SingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # controlnet_blocks
+ self.controlnet_blocks = nn.ModuleList([])
+ for _ in range(len(self.transformer_blocks)):
+ controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_blocks.append(controlnet_block)
+ pos_embed_input = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels + extra_conditioning_channels,
+ embed_dim=self.inner_dim,
+ pos_embed_type=None,
+ )
+ self.pos_embed_input = zero_module(pos_embed_input)
+
+ self.gradient_checkpointing = False
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedJointAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ # Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer
+ # we should have handled this in conversion script
+ def _get_pos_embed_from_transformer(self, transformer):
+ pos_embed = PatchEmbed(
+ height=transformer.config.sample_size,
+ width=transformer.config.sample_size,
+ patch_size=transformer.config.patch_size,
+ in_channels=transformer.config.in_channels,
+ embed_dim=transformer.inner_dim,
+ pos_embed_max_size=transformer.config.pos_embed_max_size,
+ )
+ pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=True)
+ return pos_embed
+
+ @classmethod
+ def from_transformer(
+ cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
+ ):
+ config = transformer.config
+ config["num_layers"] = num_layers or config.num_layers
+ config["extra_conditioning_channels"] = num_extra_conditioning_channels
+ controlnet = cls.from_config(config)
+
+ if load_weights_from_transformer:
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
+
+ controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
+
+ return controlnet
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`SD3Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ if self.pos_embed is not None and hidden_states.ndim != 4:
+ raise ValueError("hidden_states must be 4D when pos_embed is used")
+
+ # SD3.5 8b controlnet does not have a `pos_embed`,
+ # it use the `pos_embed` from the transformer to process input before passing to controlnet
+ elif self.pos_embed is None and hidden_states.ndim != 3:
+ raise ValueError("hidden_states must be 3D when pos_embed is not used")
+
+ if self.context_embedder is not None and encoder_hidden_states is None:
+ raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
+ # SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
+ elif self.context_embedder is None and encoder_hidden_states is not None:
+ raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
+
+ if self.pos_embed is not None:
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
+
+ temb = self.time_text_embed(timestep, pooled_projections)
+
+ if self.context_embedder is not None:
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ # add
+ hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
+
+ block_res_samples = ()
+
+ for block in self.transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ if self.context_embedder is not None:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ )
+ else:
+ # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
+ hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb)
+
+ else:
+ if self.context_embedder is not None:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+ )
+ else:
+ # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
+ hidden_states = block(hidden_states, temb)
+
+ block_res_samples = block_res_samples + (hidden_states,)
+
+ controlnet_block_res_samples = ()
+ for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+ block_res_sample = controlnet_block(block_res_sample)
+ controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+ # 6. scaling
+ controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (controlnet_block_res_samples,)
+
+ return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
+
+
+class SD3MultiControlNetModel(ModelMixin):
+ r"""
+ `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
+
+ This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
+ compatible with `SD3ControlNetModel`.
+
+ Args:
+ controlnets (`List[SD3ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `SD3ControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ controlnet_cond: List[torch.tensor],
+ conditioning_scale: List[float],
+ pooled_projections: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[SD3ControlNetOutput, Tuple]:
+ for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+ block_samples = controlnet(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ pooled_projections=pooled_projections,
+ controlnet_cond=image,
+ conditioning_scale=scale,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ else:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
+ ]
+ control_block_samples = (tuple(control_block_samples),)
+
+ return control_block_samples
diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py
new file mode 100644
index 000000000000..25348ce606d6
--- /dev/null
+++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py
@@ -0,0 +1,785 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils import BaseOutput, logging
+from ..attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from ..embeddings import TimestepEmbedding, Timesteps
+from ..modeling_utils import ModelMixin
+from ..unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
+from ..unets.unet_2d_condition import UNet2DConditionModel
+from ..unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class SparseControlNetOutput(BaseOutput):
+ """
+ The output of [`SparseControlNetModel`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_down_block_re_sample (`torch.Tensor`):
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class SparseControlNetConditioningEmbedding(nn.Module):
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+ return embedding
+
+
+class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ """
+ A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
+ Models](https://arxiv.org/abs/2311.16933).
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ conditioning_channels (`int`, defaults to 4):
+ The number of input channels in the controlnet conditional embedding module. If
+ `concat_condition_embedding` is True, the value provided here is incremented by 1.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer layers to use in each layer in the middle block.
+ attention_head_dim (`int` or `Tuple[int]`, defaults to 8):
+ The dimension of the attention heads.
+ num_attention_heads (`int` or `Tuple[int]`, *optional*):
+ The number of heads to use for multi-head attention.
+ use_linear_projection (`bool`, defaults to `False`):
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ TODO(Patrick) - unused parameter
+ controlnet_conditioning_channel_order (`str`, defaults to `rgb`):
+ motion_max_seq_length (`int`, defaults to `32`):
+ The maximum sequence length to use in the motion module.
+ motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`):
+ The number of heads to use in each attention layer of the motion module.
+ concat_conditioning_mask (`bool`, defaults to `True`):
+ use_simplified_condition_embedding (`bool`, defaults to `True`):
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 4,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlockMotion",
+ "CrossAttnDownBlockMotion",
+ "CrossAttnDownBlockMotion",
+ "DownBlockMotion",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 768,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
+ temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ use_linear_projection: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ controlnet_conditioning_channel_order: str = "rgb",
+ motion_max_seq_length: int = 32,
+ motion_num_attention_heads: int = 8,
+ concat_conditioning_mask: bool = True,
+ use_simplified_condition_embedding: bool = True,
+ ):
+ super().__init__()
+ self.use_simplified_condition_embedding = use_simplified_condition_embedding
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+ if isinstance(temporal_transformer_layers_per_block, int):
+ temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ if concat_conditioning_mask:
+ conditioning_channels = conditioning_channels + 1
+
+ self.concat_conditioning_mask = concat_conditioning_mask
+
+ # control net conditioning embedding
+ if use_simplified_condition_embedding:
+ self.controlnet_cond_embedding = zero_module(
+ nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+ )
+ else:
+ self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(motion_num_attention_heads, int):
+ motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ if down_block_type == "CrossAttnDownBlockMotion":
+ down_block = CrossAttnDownBlockMotion(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ dropout=0,
+ num_layers=layers_per_block,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ resnet_eps=norm_eps,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ resnet_pre_norm=True,
+ num_attention_heads=num_attention_heads[i],
+ cross_attention_dim=cross_attention_dim[i],
+ add_downsample=not is_final_block,
+ dual_cross_attention=False,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ temporal_num_attention_heads=motion_num_attention_heads[i],
+ temporal_max_seq_length=motion_max_seq_length,
+ temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+ temporal_double_self_attention=False,
+ )
+ elif down_block_type == "DownBlockMotion":
+ down_block = DownBlockMotion(
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ dropout=0,
+ num_layers=layers_per_block,
+ resnet_eps=norm_eps,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ resnet_pre_norm=True,
+ add_downsample=not is_final_block,
+ temporal_num_attention_heads=motion_num_attention_heads[i],
+ temporal_max_seq_length=motion_max_seq_length,
+ temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
+ temporal_double_self_attention=False,
+ )
+ else:
+ raise ValueError(
+ "Invalid `block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`"
+ )
+
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ # mid
+ mid_block_channels = block_out_channels[-1]
+
+ controlnet_block = nn.Conv2d(mid_block_channels, mid_block_channels, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+ if transformer_layers_per_mid_block is None:
+ transformer_layers_per_mid_block = (
+ transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1
+ )
+
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=mid_block_channels,
+ temb_channels=time_embed_dim,
+ dropout=0,
+ num_layers=1,
+ transformer_layers_per_block=transformer_layers_per_mid_block,
+ resnet_eps=norm_eps,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ resnet_pre_norm=True,
+ num_attention_heads=num_attention_heads[-1],
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ dual_cross_attention=False,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type="default",
+ )
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ conditioning_channels: int = 3,
+ ) -> "SparseControlNetModel":
+ r"""
+ Instantiate a [`SparseControlNetModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also
+ copied where applicable.
+ """
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ down_block_types = unet.config.down_block_types
+
+ for i in range(len(down_block_types)):
+ if "CrossAttn" in down_block_types[i]:
+ down_block_types[i] = "CrossAttnDownBlockMotion"
+ elif "Down" in down_block_types[i]:
+ down_block_types[i] = "DownBlockMotion"
+ else:
+ raise ValueError("Invalid `block_type` encountered. Must be a cross-attention or down block")
+
+ controlnet = cls(
+ in_channels=unet.config.in_channels,
+ conditioning_channels=conditioning_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ down_block_types=unet.config.down_block_types,
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ attention_head_dim=unet.config.attention_head_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ use_linear_projection=unet.config.use_linear_projection,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+ )
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False)
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False)
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False)
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
+
+ return controlnet
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ conditioning_mask: Optional[torch.Tensor] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
+ """
+ The [`SparseControlNetModel`] forward method.
+
+ Args:
+ sample (`torch.Tensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ guess_mode (`bool`, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+ Returns:
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
+ """
+ sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape
+ sample = torch.zeros_like(sample)
+
+ # check channel order
+ channel_order = self.config.controlnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ elif channel_order == "bgr":
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
+ else:
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ else:
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)
+
+ # 2. pre-process
+ batch_size, channels, num_frames, height, width = sample.shape
+
+ sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
+ sample = self.conv_in(sample)
+
+ batch_frames, channels, height, width = sample.shape
+ sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width)
+
+ if self.concat_conditioning_mask:
+ controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
+
+ batch_size, channels, num_frames, height, width = controlnet_cond.shape
+ controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape(
+ batch_size * num_frames, channels, height, width
+ )
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+ batch_frames, channels, height, width = controlnet_cond.shape
+ controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width)
+
+ sample = sample + controlnet_cond
+
+ batch_size, num_frames, channels, height, width = sample.shape
+ sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # 5. Control net blocks
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+ scales = scales * conditioning_scale
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
+ else:
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if self.config.global_pool_conditions:
+ down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return SparseControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+# Copied from diffusers.models.controlnets.controlnet.zero_module
+def zero_module(module: nn.Module) -> nn.Module:
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py
new file mode 100644
index 000000000000..26cb86718a21
--- /dev/null
+++ b/src/diffusers/models/controlnets/controlnet_union.py
@@ -0,0 +1,841 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import logging
+from ..attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from ..embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from ..modeling_utils import ModelMixin
+from ..unets.unet_2d_blocks import (
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+)
+from ..unets.unet_2d_condition import UNet2DConditionModel
+from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class QuickGELU(nn.Module):
+ """
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+ """
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return input * torch.sigmoid(1.702 * input)
+
+
+class ResidualAttentionMlp(nn.Module):
+ def __init__(self, d_model: int):
+ super().__init__()
+ self.c_fc = nn.Linear(d_model, d_model * 4)
+ self.gelu = QuickGELU()
+ self.c_proj = nn.Linear(d_model * 4, d_model)
+
+ def forward(self, x: torch.Tensor):
+ x = self.c_fc(x)
+ x = self.gelu(x)
+ x = self.c_proj(x)
+ return x
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = nn.LayerNorm(d_model)
+ self.mlp = ResidualAttentionMlp(d_model)
+ self.ln_2 = nn.LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ """
+ A ControlNetUnion model.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ use_linear_projection (`bool`, defaults to `False`):
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ num_class_embeds (`int`, *optional*, defaults to 0):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(48, 96, 192, 384)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 3,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (48, 96, 192, 384),
+ global_pool_conditions: bool = False,
+ addition_embed_type_num_heads: int = 64,
+ num_control_type: int = 6,
+ num_trans_channel: int = 320,
+ num_trans_head: int = 8,
+ num_trans_layer: int = 1,
+ num_proj_channel: int = 320,
+ ):
+ super().__init__()
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ if encoder_hid_dim_type is not None:
+ raise ValueError(f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None.")
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ # control net conditioning embedding
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ task_scale_factor = num_trans_channel**0.5
+ self.task_embedding = nn.Parameter(task_scale_factor * torch.randn(num_control_type, num_trans_channel))
+ self.transformer_layes = nn.ModuleList(
+ [ResidualAttentionBlock(num_trans_channel, num_trans_head) for _ in range(num_trans_layer)]
+ )
+ self.spatial_ch_projs = zero_module(nn.Linear(num_trans_channel, num_proj_channel))
+ self.control_type_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.control_add_embedding = TimestepEmbedding(addition_time_embed_dim * num_control_type, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ ):
+ r"""
+ Instantiate a [`ControlNetUnionModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`ControlNetUnionModel`]. All configuration options are also
+ copied where applicable.
+ """
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+
+ controlnet = cls(
+ encoder_hid_dim=encoder_hid_dim,
+ encoder_hid_dim_type=encoder_hid_dim_type,
+ addition_embed_type=addition_embed_type,
+ addition_time_embed_dim=addition_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=unet.config.in_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ down_block_types=unet.config.down_block_types,
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ attention_head_dim=unet.config.attention_head_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ use_linear_projection=unet.config.use_linear_projection,
+ class_embed_type=unet.config.class_embed_type,
+ num_class_embeds=unet.config.num_class_embeds,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ )
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ if controlnet.class_embedding:
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
+
+ return controlnet
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: List[torch.Tensor],
+ control_type: torch.Tensor,
+ control_type_idx: List[int],
+ conditioning_scale: Union[float, List[float]] = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ from_multi: bool = False,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
+ """
+ The [`ControlNetUnionModel`] forward method.
+
+ Args:
+ sample (`torch.Tensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ controlnet_cond (`List[torch.Tensor]`):
+ The conditional input tensors.
+ control_type (`torch.Tensor`):
+ A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
+ type is used.
+ control_type_idx (`List[int]`):
+ The indices of `control_type`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ from_multi (`bool`, defaults to `False`):
+ Use standard scaling when called from `MultiControlNetUnionModel`.
+ guess_mode (`bool`, defaults to `False`):
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
+ """
+ if isinstance(conditioning_scale, float):
+ conditioning_scale = [conditioning_scale] * len(controlnet_cond)
+
+ # check channel order
+ channel_order = self.config.controlnet_conditioning_channel_order
+
+ if channel_order != "rgb":
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ else:
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type is not None:
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ control_embeds = self.control_type_proj(control_type.flatten())
+ control_embeds = control_embeds.reshape((t_emb.shape[0], -1))
+ control_embeds = control_embeds.to(emb.dtype)
+ control_emb = self.control_add_embedding(control_embeds)
+ emb = emb + control_emb
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ inputs = []
+ condition_list = []
+
+ for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
+ condition = self.controlnet_cond_embedding(cond)
+ feat_seq = torch.mean(condition, dim=(2, 3))
+ feat_seq = feat_seq + self.task_embedding[control_idx]
+ if from_multi:
+ inputs.append(feat_seq.unsqueeze(1))
+ condition_list.append(condition)
+ else:
+ inputs.append(feat_seq.unsqueeze(1) * scale)
+ condition_list.append(condition * scale)
+
+ condition = sample
+ feat_seq = torch.mean(condition, dim=(2, 3))
+ inputs.append(feat_seq.unsqueeze(1))
+ condition_list.append(condition)
+
+ x = torch.cat(inputs, dim=1)
+ for layer in self.transformer_layes:
+ x = layer(x)
+
+ controlnet_cond_fuser = sample * 0.0
+ for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
+ alpha = self.spatial_ch_projs(x[:, idx])
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
+ if from_multi:
+ controlnet_cond_fuser += condition + alpha
+ else:
+ controlnet_cond_fuser += condition + alpha * scale
+
+ sample = sample + controlnet_cond_fuser
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ # 5. Control net blocks
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+ if from_multi:
+ scales = scales * conditioning_scale[0]
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
+ elif from_multi:
+ down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
+
+ if self.config.global_pool_conditions:
+ down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return ControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py
similarity index 96%
rename from src/diffusers/models/controlnet_xs.py
rename to src/diffusers/models/controlnets/controlnet_xs.py
index f676a70f060a..608be6b70277 100644
--- a/src/diffusers/models/controlnet_xs.py
+++ b/src/diffusers/models/controlnets/controlnet_xs.py
@@ -19,10 +19,10 @@
import torch.utils.checkpoint
from torch import Tensor, nn
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, is_torch_version, logging
-from ..utils.torch_utils import apply_freeu
-from .attention_processor import (
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput, logging
+from ...utils.torch_utils import apply_freeu
+from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
@@ -31,10 +31,9 @@
AttnProcessor,
FusedAttnProcessor2_0,
)
-from .controlnet import ControlNetConditioningEmbedding
-from .embeddings import TimestepEmbedding, Timesteps
-from .modeling_utils import ModelMixin
-from .unets.unet_2d_blocks import (
+from ..embeddings import TimestepEmbedding, Timesteps
+from ..modeling_utils import ModelMixin
+from ..unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
Downsample2D,
@@ -43,7 +42,8 @@
UNetMidBlock2DCrossAttn,
Upsample2D,
)
-from .unets.unet_2d_condition import UNet2DConditionModel
+from ..unets.unet_2d_condition import UNet2DConditionModel
+from .controlnet import ControlNetConditioningEmbedding
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -864,10 +864,6 @@ def freeze_unet_params(self) -> None:
for u in self.up_blocks:
u.freeze_base_params()
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -1062,7 +1058,8 @@ def forward(
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
return_dict (`bool`, defaults to `True`):
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+ Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain
+ tuple.
apply_control (`bool`, defaults to `True`):
If `False`, the input is run only through the base model.
@@ -1087,10 +1084,11 @@ def forward(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
@@ -1448,15 +1446,6 @@ def forward(
base_blocks = list(zip(self.base_resnets, self.base_attentions))
ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions))
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(
base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base
):
@@ -1465,14 +1454,8 @@ def custom_forward(*inputs):
h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
# apply base subblock
- if self.training and self.gradient_checkpointing:
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- h_base = torch.utils.checkpoint.checkpoint(
- create_custom_forward(b_res),
- h_base,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ h_base = self._gradient_checkpointing_func(b_res, h_base, temb)
else:
h_base = b_res(h_base, temb)
@@ -1488,14 +1471,8 @@ def custom_forward(*inputs):
# apply ctrl subblock
if apply_control:
- if self.training and self.gradient_checkpointing:
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- h_ctrl = torch.utils.checkpoint.checkpoint(
- create_custom_forward(c_res),
- h_ctrl,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ h_ctrl = self._gradient_checkpointing_func(c_res, h_ctrl, temb)
else:
h_ctrl = c_res(h_ctrl, temb)
if c_attn is not None:
@@ -1860,15 +1837,6 @@ def forward(
and getattr(self, "b2", None)
)
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
@@ -1897,14 +1865,8 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
- if self.training and self.gradient_checkpointing:
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
diff --git a/src/diffusers/models/controlnets/multicontrolnet.py b/src/diffusers/models/controlnets/multicontrolnet.py
new file mode 100644
index 000000000000..44bfcc1b82a9
--- /dev/null
+++ b/src/diffusers/models/controlnets/multicontrolnet.py
@@ -0,0 +1,183 @@
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
+from ...models.modeling_utils import ModelMixin
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MultiControlNetModel(ModelMixin):
+ r"""
+ Multiple `ControlNetModel` wrapper class for Multi-ControlNet
+
+ This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
+ compatible with `ControlNetModel`.
+
+ Args:
+ controlnets (`List[ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `ControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: List[torch.tensor],
+ conditioning_scale: List[float],
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple]:
+ for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+ down_samples, mid_sample = controlnet(
+ sample=sample,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ controlnet_cond=image,
+ conditioning_scale=scale,
+ class_labels=class_labels,
+ timestep_cond=timestep_cond,
+ attention_mask=attention_mask,
+ added_cond_kwargs=added_cond_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
+ guess_mode=guess_mode,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+ else:
+ down_block_res_samples = [
+ samples_prev + samples_curr
+ for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+ ]
+ mid_block_res_sample += mid_sample
+
+ return down_block_res_samples, mid_block_res_sample
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ variant: Optional[str] = None,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ `[`~models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained`]` class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful when in distributed training like
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+ the main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+ need to replace `torch.save` by another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+ variant (`str`, *optional*):
+ If specified, weights are saved in the format pytorch_model..bin.
+ """
+ for idx, controlnet in enumerate(self.nets):
+ suffix = "" if idx == 0 else f"_{idx}"
+ controlnet.save_pretrained(
+ save_directory + suffix,
+ is_main_process=is_main_process,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ variant=variant,
+ )
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_path (`os.PathLike`):
+ A path to a *directory* containing model weights saved using
+ [`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
+ `./my_model_directory/controlnet`.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+ GPU and the available CPU RAM if unset.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+ setting this argument to `True` will raise an error.
+ variant (`str`, *optional*):
+ If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is
+ ignored when using `from_flax`.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
+ `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
+ `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+ """
+ idx = 0
+ controlnets = []
+
+ # load controlnet and append to list until no controlnet directory exists anymore
+ # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
+ # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
+ model_path_to_load = pretrained_model_path
+ while os.path.isdir(model_path_to_load):
+ controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
+ controlnets.append(controlnet)
+
+ idx += 1
+ model_path_to_load = pretrained_model_path + f"_{idx}"
+
+ logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
+
+ if len(controlnets) == 0:
+ raise ValueError(
+ f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
+ )
+
+ return cls(controlnets)
diff --git a/src/diffusers/models/controlnets/multicontrolnet_union.py b/src/diffusers/models/controlnets/multicontrolnet_union.py
new file mode 100644
index 000000000000..427e05b19110
--- /dev/null
+++ b/src/diffusers/models/controlnets/multicontrolnet_union.py
@@ -0,0 +1,196 @@
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...models.controlnets.controlnet import ControlNetOutput
+from ...models.controlnets.controlnet_union import ControlNetUnionModel
+from ...models.modeling_utils import ModelMixin
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MultiControlNetUnionModel(ModelMixin):
+ r"""
+ Multiple `ControlNetUnionModel` wrapper class for Multi-ControlNet-Union.
+
+ This module is a wrapper for multiple instances of the `ControlNetUnionModel`. The `forward()` API is designed to
+ be compatible with `ControlNetUnionModel`.
+
+ Args:
+ controlnets (`List[ControlNetUnionModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `ControlNetUnionModel` as a list.
+ """
+
+ def __init__(self, controlnets: Union[List[ControlNetUnionModel], Tuple[ControlNetUnionModel]]):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: List[torch.tensor],
+ control_type: List[torch.Tensor],
+ control_type_idx: List[List[int]],
+ conditioning_scale: List[float],
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[ControlNetOutput, Tuple]:
+ down_block_res_samples, mid_block_res_sample = None, None
+ for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
+ zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
+ ):
+ if scale == 0.0:
+ continue
+ down_samples, mid_sample = controlnet(
+ sample=sample,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ controlnet_cond=image,
+ control_type=ctype,
+ control_type_idx=ctype_idx,
+ conditioning_scale=scale,
+ class_labels=class_labels,
+ timestep_cond=timestep_cond,
+ attention_mask=attention_mask,
+ added_cond_kwargs=added_cond_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
+ from_multi=True,
+ guess_mode=guess_mode,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if down_block_res_samples is None and mid_block_res_sample is None:
+ down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+ else:
+ down_block_res_samples = [
+ samples_prev + samples_curr
+ for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+ ]
+ mid_block_res_sample += mid_sample
+
+ return down_block_res_samples, mid_block_res_sample
+
+ # Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained with ControlNet->ControlNetUnion
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ variant: Optional[str] = None,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ `[`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.from_pretrained`]` class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful when in distributed training like
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+ the main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+ need to replace `torch.save` by another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+ variant (`str`, *optional*):
+ If specified, weights are saved in the format pytorch_model..bin.
+ """
+ for idx, controlnet in enumerate(self.nets):
+ suffix = "" if idx == 0 else f"_{idx}"
+ controlnet.save_pretrained(
+ save_directory + suffix,
+ is_main_process=is_main_process,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ variant=variant,
+ )
+
+ @classmethod
+ # Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained with ControlNet->ControlNetUnion
+ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a pretrained MultiControlNetUnion model from multiple pre-trained controlnet models.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_path (`os.PathLike`):
+ A path to a *directory* containing model weights saved using
+ [`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
+ `./my_model_directory/controlnet`.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+ GPU and the available CPU RAM if unset.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+ setting this argument to `True` will raise an error.
+ variant (`str`, *optional*):
+ If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is
+ ignored when using `from_flax`.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
+ `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
+ `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+ """
+ idx = 0
+ controlnets = []
+
+ # load controlnet and append to list until no controlnet directory exists anymore
+ # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
+ # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
+ model_path_to_load = pretrained_model_path
+ while os.path.isdir(model_path_to_load):
+ controlnet = ControlNetUnionModel.from_pretrained(model_path_to_load, **kwargs)
+ controlnets.append(controlnet)
+
+ idx += 1
+ model_path_to_load = pretrained_model_path + f"_{idx}"
+
+ logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
+
+ if len(controlnets) == 0:
+ raise ValueError(
+ f"No ControlNetUnions found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
+ )
+
+ return cls(controlnets)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 44f01c46ebe8..b1e14ca6a7fe 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -84,15 +84,108 @@ def get_3d_sincos_pos_embed(
temporal_size: int,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
+ device: Optional[torch.device] = None,
+ output_type: str = "np",
+) -> torch.Tensor:
+ r"""
+ Creates 3D sinusoidal positional embeddings.
+
+ Args:
+ embed_dim (`int`):
+ The embedding dimension of inputs. It must be divisible by 16.
+ spatial_size (`int` or `Tuple[int, int]`):
+ The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+ spatial dimensions (height and width).
+ temporal_size (`int`):
+ The temporal dimension of postional embeddings (number of frames).
+ spatial_interpolation_scale (`float`, defaults to 1.0):
+ Scale factor for spatial grid interpolation.
+ temporal_interpolation_scale (`float`, defaults to 1.0):
+ Scale factor for temporal grid interpolation.
+
+ Returns:
+ `torch.Tensor`:
+ The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+ embed_dim]`.
+ """
+ if output_type == "np":
+ return _get_3d_sincos_pos_embed_np(
+ embed_dim=embed_dim,
+ spatial_size=spatial_size,
+ temporal_size=temporal_size,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ )
+ if embed_dim % 4 != 0:
+ raise ValueError("`embed_dim` must be divisible by 4")
+ if isinstance(spatial_size, int):
+ spatial_size = (spatial_size, spatial_size)
+
+ embed_dim_spatial = 3 * embed_dim // 4
+ embed_dim_temporal = embed_dim // 4
+
+ # 1. Spatial
+ grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
+ grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
+ grid = torch.stack(grid, dim=0)
+
+ grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
+
+ # 2. Temporal
+ grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
+
+ # 3. Concat
+ pos_embed_spatial = pos_embed_spatial[None, :, :]
+ pos_embed_spatial = pos_embed_spatial.repeat_interleave(
+ temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
+ ) # [T, H*W, D // 4 * 3]
+
+ pos_embed_temporal = pos_embed_temporal[:, None, :]
+ pos_embed_temporal = pos_embed_temporal.repeat_interleave(
+ spatial_size[0] * spatial_size[1], dim=1
+ ) # [T, H*W, D // 4]
+
+ pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1) # [T, H*W, D]
+ return pos_embed
+
+
+def _get_3d_sincos_pos_embed_np(
+ embed_dim: int,
+ spatial_size: Union[int, Tuple[int, int]],
+ temporal_size: int,
+ spatial_interpolation_scale: float = 1.0,
+ temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
r"""
+ Creates 3D sinusoidal positional embeddings.
+
Args:
embed_dim (`int`):
+ The embedding dimension of inputs. It must be divisible by 16.
spatial_size (`int` or `Tuple[int, int]`):
+ The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+ spatial dimensions (height and width).
temporal_size (`int`):
+ The temporal dimension of postional embeddings (number of frames).
spatial_interpolation_scale (`float`, defaults to 1.0):
+ Scale factor for spatial grid interpolation.
temporal_interpolation_scale (`float`, defaults to 1.0):
+ Scale factor for temporal grid interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+ embed_dim]`.
"""
+ deprecation_message = (
+ "`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
+ " `from_numpy` is no longer required."
+ " Pass `output_type='pt' to use the new version now."
+ )
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
if embed_dim % 4 != 0:
raise ValueError("`embed_dim` must be divisible by 4")
if isinstance(spatial_size, int):
@@ -126,11 +219,164 @@ def get_3d_sincos_pos_embed(
def get_2d_sincos_pos_embed(
+ embed_dim,
+ grid_size,
+ cls_token=False,
+ extra_tokens=0,
+ interpolation_scale=1.0,
+ base_size=16,
+ device: Optional[torch.device] = None,
+ output_type: str = "np",
+):
+ """
+ Creates 2D sinusoidal positional embeddings.
+
+ Args:
+ embed_dim (`int`):
+ The embedding dimension.
+ grid_size (`int`):
+ The size of the grid height and width.
+ cls_token (`bool`, defaults to `False`):
+ Whether or not to add a classification token.
+ extra_tokens (`int`, defaults to `0`):
+ The number of extra tokens to add.
+ interpolation_scale (`float`, defaults to `1.0`):
+ The scale of the interpolation.
+
+ Returns:
+ pos_embed (`torch.Tensor`):
+ Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+ embed_dim]` if using cls_token
+ """
+ if output_type == "np":
+ deprecation_message = (
+ "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+ " `from_numpy` is no longer required."
+ " Pass `output_type='pt' to use the new version now."
+ )
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+ return get_2d_sincos_pos_embed_np(
+ embed_dim=embed_dim,
+ grid_size=grid_size,
+ cls_token=cls_token,
+ extra_tokens=extra_tokens,
+ interpolation_scale=interpolation_scale,
+ base_size=base_size,
+ )
+ if isinstance(grid_size, int):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = (
+ torch.arange(grid_size[0], device=device, dtype=torch.float32)
+ / (grid_size[0] / base_size)
+ / interpolation_scale
+ )
+ grid_w = (
+ torch.arange(grid_size[1], device=device, dtype=torch.float32)
+ / (grid_size[1] / base_size)
+ / interpolation_scale
+ )
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
+ grid = torch.stack(grid, dim=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
+ if cls_token and extra_tokens > 0:
+ pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
+ r"""
+ This function generates 2D sinusoidal positional embeddings from a grid.
+
+ Args:
+ embed_dim (`int`): The embedding dimension.
+ grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
+
+ Returns:
+ `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+ """
+ if output_type == "np":
+ deprecation_message = (
+ "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+ " `from_numpy` is no longer required."
+ " Pass `output_type='pt' to use the new version now."
+ )
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+ return get_2d_sincos_pos_embed_from_grid_np(
+ embed_dim=embed_dim,
+ grid=grid,
+ )
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2)
+
+ emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+ """
+ This function generates 1D positional embeddings from a grid.
+
+ Args:
+ embed_dim (`int`): The embedding dimension `D`
+ pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+
+ Returns:
+ `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
+ """
+ if output_type == "np":
+ deprecation_message = (
+ "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+ " `from_numpy` is no longer required."
+ " Pass `output_type='pt' to use the new version now."
+ )
+ deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
+ return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.outer(pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb
+
+
+def get_2d_sincos_pos_embed_np(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
"""
- grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
- [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ Creates 2D sinusoidal positional embeddings.
+
+ Args:
+ embed_dim (`int`):
+ The embedding dimension.
+ grid_size (`int`):
+ The size of the grid height and width.
+ cls_token (`bool`, defaults to `False`):
+ Whether or not to add a classification token.
+ extra_tokens (`int`, defaults to `0`):
+ The number of extra tokens to add.
+ interpolation_scale (`float`, defaults to `1.0`):
+ The scale of the interpolation.
+
+ Returns:
+ pos_embed (`np.ndarray`):
+ Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+ embed_dim]` if using cls_token
"""
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)
@@ -141,27 +387,44 @@ def get_2d_sincos_pos_embed(
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
+ r"""
+ This function generates 2D sinusoidal positional embeddings from a grid.
+
+ Args:
+ embed_dim (`int`): The embedding dimension.
+ grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
+
+ Returns:
+ `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+ """
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
# use half of dimensions to encode grid_h
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+ emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
"""
- embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
+ This function generates 1D positional embeddings from a grid.
+
+ Args:
+ embed_dim (`int`): The embedding dimension `D`
+ pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
+
+ Returns:
+ `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
@@ -181,7 +444,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
class PatchEmbed(nn.Module):
- """2D Image to Patch Embedding with support for SD3 cropping."""
+ """
+ 2D Image to Patch Embedding with support for SD3 cropping.
+
+ Args:
+ height (`int`, defaults to `224`): The height of the image.
+ width (`int`, defaults to `224`): The width of the image.
+ patch_size (`int`, defaults to `16`): The size of the patches.
+ in_channels (`int`, defaults to `3`): The number of input channels.
+ embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+ layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
+ flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
+ bias (`bool`, defaults to `True`): Whether or not to use bias.
+ interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
+ pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
+ pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
+ """
def __init__(
self,
@@ -227,10 +505,14 @@ def __init__(
self.pos_embed = None
elif pos_embed_type == "sincos":
pos_embed = get_2d_sincos_pos_embed(
- embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ embed_dim,
+ grid_size,
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ output_type="pt",
)
persistent = True if pos_embed_max_size else False
- self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
+ self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
else:
raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
@@ -262,7 +544,6 @@ def forward(self, latent):
height, width = latent.shape[-2:]
else:
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
-
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
@@ -280,8 +561,10 @@ def forward(self, latent):
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
+ device=latent.device,
+ output_type="pt",
)
- pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+ pos_embed = pos_embed.float().unsqueeze(0)
else:
pos_embed = self.pos_embed
@@ -289,7 +572,15 @@ def forward(self, latent):
class LuminaPatchEmbed(nn.Module):
- """2D Image to Patch Embedding with support for Lumina-T2X"""
+ """
+ 2D Image to Patch Embedding with support for Lumina-T2X
+
+ Args:
+ patch_size (`int`, defaults to `2`): The size of the patches.
+ in_channels (`int`, defaults to `4`): The number of input channels.
+ embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+ bias (`bool`, defaults to `True`): Whether or not to use bias.
+ """
def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
super().__init__()
@@ -338,6 +629,7 @@ class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
+ patch_size_t: Optional[int] = None,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
@@ -355,6 +647,7 @@ def __init__(
super().__init__()
self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
self.embed_dim = embed_dim
self.sample_height = sample_height
self.sample_width = sample_width
@@ -366,9 +659,15 @@ def __init__(
self.use_positional_embeddings = use_positional_embeddings
self.use_learned_positional_embeddings = use_learned_positional_embeddings
- self.proj = nn.Conv2d(
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
- )
+ if patch_size_t is None:
+ # CogVideoX 1.0 checkpoints
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ else:
+ # CogVideoX 1.5 checkpoints
+ self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
+
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
if use_positional_embeddings or use_learned_positional_embeddings:
@@ -376,7 +675,9 @@ def __init__(
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
- def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+ def _get_positional_embeddings(
+ self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
+ ) -> torch.Tensor:
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -388,9 +689,11 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp
post_time_compression_frames,
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
+ device=device,
+ output_type="pt",
)
- pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
- joint_pos_embedding = torch.zeros(
+ pos_embedding = pos_embedding.flatten(0, 1)
+ joint_pos_embedding = pos_embedding.new_zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
@@ -407,12 +710,24 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
"""
text_embeds = self.text_proj(text_embeds)
- batch, num_frames, channels, height, width = image_embeds.shape
- image_embeds = image_embeds.reshape(-1, channels, height, width)
- image_embeds = self.proj(image_embeds)
- image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
- image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
- image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+ batch_size, num_frames, channels, height, width = image_embeds.shape
+
+ if self.patch_size_t is None:
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
+ image_embeds = self.proj(image_embeds)
+ image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+ else:
+ p = self.patch_size
+ p_t = self.patch_size_t
+
+ image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
+ image_embeds = image_embeds.reshape(
+ batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
+ )
+ image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
+ image_embeds = self.proj(image_embeds)
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
@@ -432,11 +747,13 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
or self.sample_width != width
or self.sample_frames != pre_time_compression_frames
):
- pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
- pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+ pos_embedding = self._get_positional_embeddings(
+ height, width, pre_time_compression_frames, device=embeds.device
+ )
else:
pos_embedding = self.pos_embedding
+ pos_embedding = pos_embedding.to(dtype=embeds.dtype)
embeds = embeds + pos_embedding
return embeds
@@ -463,9 +780,11 @@ def __init__(
# Linear projection for text embeddings
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
- pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+ pos_embed = get_2d_sincos_pos_embed(
+ hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
+ )
pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
- self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+ self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
@@ -497,7 +816,15 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens
def get_3d_rotary_pos_embed(
- embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+ embed_dim,
+ crops_coords,
+ grid_size,
+ temporal_size,
+ theta: int = 10000,
+ use_real: bool = True,
+ grid_type: str = "linspace",
+ max_size: Optional[Tuple[int, int]] = None,
+ device: Optional[torch.device] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
@@ -513,17 +840,36 @@ def get_3d_rotary_pos_embed(
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
+ grid_type (`str`):
+ Whether to use "linspace" or "slice" to compute grids.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
if use_real is not True:
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
- start, stop = crops_coords
- grid_size_h, grid_size_w = grid_size
- grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
- grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
- grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+ if grid_type == "linspace":
+ start, stop = crops_coords
+ grid_size_h, grid_size_w = grid_size
+ grid_h = torch.linspace(
+ start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+ )
+ grid_w = torch.linspace(
+ start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+ )
+ grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+ grid_t = torch.linspace(
+ 0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+ )
+ elif grid_type == "slice":
+ max_h, max_w = max_size
+ grid_size_h, grid_size_w = grid_size
+ grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
+ grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
+ grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+ else:
+ raise ValueError("Invalid value passed for `grid_type`.")
# Compute dimensions for each axis
dim_t = embed_dim // 4
@@ -531,10 +877,10 @@ def get_3d_rotary_pos_embed(
dim_w = embed_dim // 8 * 3
# Temporal frequencies
- freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
# Spatial frequencies for height and width
- freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
- freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
# BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
@@ -559,12 +905,111 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
+
+ if grid_type == "slice":
+ t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
+ h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
+ w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
+
cos = combine_time_height_width(t_cos, h_cos, w_cos)
sin = combine_time_height_width(t_sin, h_sin, w_sin)
return cos, sin
-def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+def get_3d_rotary_pos_embed_allegro(
+ embed_dim,
+ crops_coords,
+ grid_size,
+ temporal_size,
+ interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+ theta: int = 10000,
+ device: Optional[torch.device] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ # TODO(aryan): docs
+ start, stop = crops_coords
+ grid_size_h, grid_size_w = grid_size
+ interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
+ grid_t = torch.linspace(
+ 0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+ )
+ grid_h = torch.linspace(
+ start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+ )
+ grid_w = torch.linspace(
+ start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+ )
+
+ # Compute dimensions for each axis
+ dim_t = embed_dim // 3
+ dim_h = embed_dim // 3
+ dim_w = embed_dim // 3
+
+ # Temporal frequencies
+ freqs_t = get_1d_rotary_pos_embed(
+ dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False
+ )
+ # Spatial frequencies for height and width
+ freqs_h = get_1d_rotary_pos_embed(
+ dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False
+ )
+ freqs_w = get_1d_rotary_pos_embed(
+ dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False
+ )
+
+ return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
+
+
+def get_2d_rotary_pos_embed(
+ embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
+):
+ """
+ RoPE for image tokens with 2d structure.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size
+ crops_coords (`Tuple[int]`)
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the positional embedding.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ device: (`torch.device`, **optional**):
+ The device used to create tensors.
+
+ Returns:
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+ """
+ if output_type == "np":
+ deprecation_message = (
+ "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+ " `from_numpy` is no longer required."
+ " Pass `output_type='pt' to use the new version now."
+ )
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+ return _get_2d_rotary_pos_embed_np(
+ embed_dim=embed_dim,
+ crops_coords=crops_coords,
+ grid_size=grid_size,
+ use_real=use_real,
+ )
+ start, stop = crops_coords
+ # scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
+ grid_h = torch.linspace(
+ start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
+ )
+ grid_w = torch.linspace(
+ start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
+ )
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0) # [2, W, H]
+
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+ return pos_embed
+
+
+def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
@@ -593,6 +1038,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+ """
+ Get 2D RoPE from grid.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size, corresponding to hidden_size_head.
+ grid (`np.ndarray`):
+ The grid of the positional embedding.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+ Returns:
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+ """
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
@@ -613,6 +1072,23 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+ """
+ Get 2D RoPE from grid.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size, corresponding to hidden_size_head.
+ grid (`np.ndarray`):
+ The grid of the positional embedding.
+ linear_factor (`float`):
+ The linear factor of the positional embedding, which is used to scale the positional embedding in the linear
+ layer.
+ ntk_factor (`float`):
+ The ntk factor of the positional embedding, which is used to scale the positional embedding in the ntk layer.
+
+ Returns:
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+ """
assert embed_dim % 4 == 0
emb_h = get_1d_rotary_pos_embed(
@@ -678,13 +1154,16 @@ def get_1d_rotary_pos_embed(
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
+ is_npu = freqs.device.type == "npu"
+ if is_npu:
+ freqs = freqs.float()
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
- freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
- freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
- # stable audio
+ # stable audio, allegro
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
@@ -725,7 +1204,7 @@ def apply_rotary_emb(
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
- # Used for Stable Audio
+ # Used for Stable Audio, OmniGen and CogView4
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
@@ -743,6 +1222,24 @@ def apply_rotary_emb(
return x_out.type_as(x)
+def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
+ # TODO(aryan): rewrite
+ def apply_1d_rope(tokens, pos, cos, sin):
+ cos = F.embedding(pos, cos)[:, None, :, :]
+ sin = F.embedding(pos, sin)[:, None, :, :]
+ x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :]
+ tokens_rotated = torch.cat((-x2, x1), dim=-1)
+ return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype)
+
+ (t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis
+ t, h, w = x.chunk(3, dim=-1)
+ t = apply_1d_rope(t, positions[0], t_cos, t_sin)
+ h = apply_1d_rope(h, positions[1], h_cos, h_sin)
+ w = apply_1d_rope(w, positions[2], w_cos, w_sin)
+ x = torch.cat([t, h, w], dim=-1)
+ return x
+
+
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
@@ -756,10 +1253,16 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
- freqs_dtype = torch.float32 if is_mps else torch.float64
+ is_npu = ids.device.type == "npu"
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
- self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
@@ -1038,7 +1541,7 @@ def forward(self, image_embeds: torch.Tensor):
batch_size = image_embeds.shape[0]
# image
- image_embeds = self.image_embeds(image_embeds)
+ image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
image_embeds = self.norm(image_embeds)
return image_embeds
@@ -1289,7 +1792,7 @@ def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embeddi
def forward(self, timestep, caption_feat, caption_mask):
# timestep embedding:
time_freq = self.time_proj(timestep)
- time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
+ time_embed = self.timestep_embedder(time_freq.to(dtype=caption_feat.dtype))
# caption condition embedding:
caption_mask_float = caption_mask.float().unsqueeze(-1)
@@ -1302,6 +1805,41 @@ def forward(self, timestep, caption_feat, caption_mask):
return conditioning
+class MochiCombinedTimestepCaptionEmbedding(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ pooled_projection_dim: int,
+ text_embed_dim: int,
+ time_embed_dim: int = 256,
+ num_attention_heads: int = 8,
+ ) -> None:
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
+ self.pooler = MochiAttentionPool(
+ num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim
+ )
+ self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim)
+
+ def forward(
+ self,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ):
+ time_proj = self.time_proj(timestep)
+ time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype))
+
+ pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask)
+ caption_proj = self.caption_proj(encoder_hidden_states)
+
+ conditioning = time_emb + pooled_projections
+ return conditioning, caption_proj
+
+
class TextTimeEmbedding(nn.Module):
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
super().__init__()
@@ -1430,6 +1968,88 @@ def shape(x):
return a[:, 0, :] # cls_token
+class MochiAttentionPool(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ embed_dim: int,
+ output_dim: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+
+ self.output_dim = output_dim or embed_dim
+ self.num_attention_heads = num_attention_heads
+
+ self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
+ self.to_q = nn.Linear(embed_dim, embed_dim)
+ self.to_out = nn.Linear(embed_dim, self.output_dim)
+
+ @staticmethod
+ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
+ """
+ Pool tokens in x using mask.
+
+ NOTE: We assume x does not require gradients.
+
+ Args:
+ x: (B, L, D) tensor of tokens.
+ mask: (B, L) boolean tensor indicating which tokens are not padding.
+
+ Returns:
+ pooled: (B, D) tensor of pooled tokens.
+ """
+ assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
+ assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
+ mask = mask[:, :, None].to(dtype=x.dtype)
+ mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
+ pooled = (x * mask).sum(dim=1, keepdim=keepdim)
+ return pooled
+
+ def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+ r"""
+ Args:
+ x (`torch.Tensor`):
+ Tensor of shape `(B, S, D)` of input tokens.
+ mask (`torch.Tensor`):
+ Boolean ensor of shape `(B, S)` indicating which tokens are not padding.
+
+ Returns:
+ `torch.Tensor`:
+ `(B, D)` tensor of pooled tokens.
+ """
+ D = x.size(2)
+
+ # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
+ attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
+ attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
+
+ # Average non-padding token features. These will be used as the query.
+ x_pool = self.pool_tokens(x, mask, keepdim=True) # (B, 1, D)
+
+ # Concat pooled features to input sequence.
+ x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
+
+ # Compute queries, keys, values. Only the mean token is used to create a query.
+ kv = self.to_kv(x) # (B, L+1, 2 * D)
+ q = self.to_q(x[:, 0]) # (B, D)
+
+ # Extract heads.
+ head_dim = D // self.num_attention_heads
+ kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
+ kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
+ k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
+ q = q.unflatten(1, (self.num_attention_heads, head_dim)) # (B, H, head_dim)
+ q = q.unsqueeze(2) # (B, H, 1, head_dim)
+
+ # Compute attention.
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim)
+
+ # Concatenate heads and run output.
+ x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
+ x = self.to_out(x)
+ return x
+
+
def get_fourier_embeds_from_boundingbox(embed_dim, box):
"""
Args:
@@ -1782,11 +2402,197 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
return out
+class IPAdapterTimeImageProjectionBlock(nn.Module):
+ """Block for IPAdapterTimeImageProjection.
+
+ Args:
+ hidden_dim (`int`, defaults to 1280):
+ The number of hidden channels.
+ dim_head (`int`, defaults to 64):
+ The number of head channels.
+ heads (`int`, defaults to 20):
+ Parallel attention heads.
+ ffn_ratio (`int`, defaults to 4):
+ The expansion ratio of feedforward network hidden layer channels.
+ """
+
+ def __init__(
+ self,
+ hidden_dim: int = 1280,
+ dim_head: int = 64,
+ heads: int = 20,
+ ffn_ratio: int = 4,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward
+
+ self.ln0 = nn.LayerNorm(hidden_dim)
+ self.ln1 = nn.LayerNorm(hidden_dim)
+ self.attn = Attention(
+ query_dim=hidden_dim,
+ cross_attention_dim=hidden_dim,
+ dim_head=dim_head,
+ heads=heads,
+ bias=False,
+ out_bias=False,
+ )
+ self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
+
+ # AdaLayerNorm
+ self.adaln_silu = nn.SiLU()
+ self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
+ self.adaln_norm = nn.LayerNorm(hidden_dim)
+
+ # Set attention scale and fuse KV
+ self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
+ self.attn.fuse_projections()
+ self.attn.to_k = None
+ self.attn.to_v = None
+
+ def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ x (`torch.Tensor`):
+ Image features.
+ latents (`torch.Tensor`):
+ Latent features.
+ timestep_emb (`torch.Tensor`):
+ Timestep embedding.
+
+ Returns:
+ `torch.Tensor`: Output latent features.
+ """
+
+ # Shift and scale for AdaLayerNorm
+ emb = self.adaln_proj(self.adaln_silu(timestep_emb))
+ shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
+
+ # Fused Attention
+ residual = latents
+ x = self.ln0(x)
+ latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+
+ batch_size = latents.shape[0]
+
+ query = self.attn.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // self.attn.heads
+
+ query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+
+ weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ latents = weight @ value
+
+ latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
+ latents = self.attn.to_out[0](latents)
+ latents = self.attn.to_out[1](latents)
+ latents = latents + residual
+
+ ## FeedForward
+ residual = latents
+ latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ return self.ff(latents) + residual
+
+
+# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+class IPAdapterTimeImageProjection(nn.Module):
+ """Resampler of SD3 IP-Adapter with timestep embedding.
+
+ Args:
+ embed_dim (`int`, defaults to 1152):
+ The feature dimension.
+ output_dim (`int`, defaults to 2432):
+ The number of output channels.
+ hidden_dim (`int`, defaults to 1280):
+ The number of hidden channels.
+ depth (`int`, defaults to 4):
+ The number of blocks.
+ dim_head (`int`, defaults to 64):
+ The number of head channels.
+ heads (`int`, defaults to 20):
+ Parallel attention heads.
+ num_queries (`int`, defaults to 64):
+ The number of queries.
+ ffn_ratio (`int`, defaults to 4):
+ The expansion ratio of feedforward network hidden layer channels.
+ timestep_in_dim (`int`, defaults to 320):
+ The number of input channels for timestep embedding.
+ timestep_flip_sin_to_cos (`bool`, defaults to True):
+ Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
+ timestep_freq_shift (`int`, defaults to 0):
+ Controls the timestep delta between frequencies between dimensions.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int = 1152,
+ output_dim: int = 2432,
+ hidden_dim: int = 1280,
+ depth: int = 4,
+ dim_head: int = 64,
+ heads: int = 20,
+ num_queries: int = 64,
+ ffn_ratio: int = 4,
+ timestep_in_dim: int = 320,
+ timestep_flip_sin_to_cos: bool = True,
+ timestep_freq_shift: int = 0,
+ ) -> None:
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
+ self.proj_in = nn.Linear(embed_dim, hidden_dim)
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+ self.layers = nn.ModuleList(
+ [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
+ )
+ self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
+ self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
+
+ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward pass.
+
+ Args:
+ x (`torch.Tensor`):
+ Image features.
+ timestep (`torch.Tensor`):
+ Timestep in denoising process.
+ Returns:
+ `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
+ """
+ timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
+ timestep_emb = self.time_embedding(timestep_emb)
+
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+ x = x + timestep_emb[:, None]
+
+ for block in self.layers:
+ latents = block(x, latents, timestep_emb)
+
+ latents = self.proj_out(latents)
+ latents = self.norm_out(latents)
+
+ return latents, timestep_emb
+
+
class MultiIPAdapterImageProjection(nn.Module):
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
super().__init__()
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
+ @property
+ def num_ip_adapters(self) -> int:
+ """Number of IP-Adapters loaded."""
+ return len(self.image_projection_layers)
+
def forward(self, image_embeds: List[torch.Tensor]):
projected_image_embeds = []
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 5277ad2f9389..741f7075d76d 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,16 +17,20 @@
import importlib
import inspect
import os
+from array import array
from collections import OrderedDict
from pathlib import Path
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
+from zipfile import is_zipfile
import safetensors
import torch
+from huggingface_hub import DDUFEntry
from huggingface_hub.utils import EntryNotFoundError
-from ..quantizers.quantization_config import QuantizationMethod
+from ..quantizers import DiffusersQuantizer
from ..utils import (
+ GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
WEIGHTS_INDEX_NAME,
@@ -34,6 +38,8 @@
_get_model_file,
deprecate,
is_accelerate_available,
+ is_gguf_available,
+ is_torch_available,
is_torch_version,
logging,
)
@@ -51,7 +57,7 @@
if is_accelerate_available():
from accelerate import infer_auto_device_map
- from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
+ from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
# Adapted from `transformers` (see modeling_utils.py)
@@ -128,25 +134,61 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
+ """
+ Find the device of param_name from the device_map.
+ """
+ if device_map is None:
+ return "cpu"
+ else:
+ module_name = param_name
+ # find next higher level module that is defined in device_map:
+ # bert.lm_head.weight -> bert.lm_head -> bert -> ''
+ while len(module_name) > 0 and module_name not in device_map:
+ module_name = ".".join(module_name.split(".")[:-1])
+ if module_name == "" and "" not in device_map:
+ raise ValueError(f"{param_name} doesn't have any device set.")
+ return device_map[module_name]
+
+
+def load_state_dict(
+ checkpoint_file: Union[str, os.PathLike],
+ dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+ disable_mmap: bool = False,
+ map_location: Union[str, torch.device] = "cpu",
+):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
- # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
- # when refactoring the _merge_sharded_checkpoints() method later.
+ # TODO: maybe refactor a bit this part where we pass a dict here
if isinstance(checkpoint_file, dict):
return checkpoint_file
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
- return safetensors.torch.load_file(checkpoint_file, device="cpu")
+ if dduf_entries:
+ # tensors are loaded on cpu
+ with dduf_entries[checkpoint_file].as_mmap() as mm:
+ return safetensors.torch.load(mm)
+ if disable_mmap:
+ return safetensors.torch.load(open(checkpoint_file, "rb").read())
+ else:
+ return safetensors.torch.load_file(checkpoint_file, device=map_location)
+ elif file_extension == GGUF_FILE_EXTENSION:
+ return load_gguf_checkpoint(checkpoint_file)
else:
+ extra_args = {}
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
- return torch.load(
- checkpoint_file,
- map_location="cpu",
- **weights_only_kwarg,
- )
+ # mmap can only be used with files serialized with zipfile-based format.
+ if (
+ isinstance(checkpoint_file, str)
+ and map_location != "meta"
+ and is_torch_version(">=", "2.1.0")
+ and is_zipfile(checkpoint_file)
+ and not disable_mmap
+ ):
+ extra_args = {"mmap": True}
+ return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
except Exception as e:
try:
with open(checkpoint_file) as f:
@@ -170,21 +212,24 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
def load_model_dict_into_meta(
model,
state_dict: OrderedDict,
- device: Optional[Union[str, torch.device]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
- hf_quantizer=None,
- keep_in_fp32_modules=None,
+ hf_quantizer: Optional[DiffusersQuantizer] = None,
+ keep_in_fp32_modules: Optional[List] = None,
+ device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
+ unexpected_keys: Optional[List[str]] = None,
+ offload_folder: Optional[Union[str, os.PathLike]] = None,
+ offload_index: Optional[Dict] = None,
+ state_dict_index: Optional[Dict] = None,
+ state_dict_folder: Optional[Union[str, os.PathLike]] = None,
) -> List[str]:
- if hf_quantizer is None:
- device = device or torch.device("cpu")
- dtype = dtype or torch.float32
- is_quantized = hf_quantizer is not None
- is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+ """
+ This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
+ params on a `meta` device. It replaces the model params with the data from the `state_dict`
+ """
- accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+ is_quantized = hf_quantizer is not None
empty_state_dict = model.state_dict()
- unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
for param_name, param in state_dict.items():
if param_name not in empty_state_dict:
@@ -194,43 +239,74 @@ def load_model_dict_into_meta(
# We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
# TODO: revisit cases when param.dtype == torch.float8_e4m3fn
- if torch.is_floating_point(param):
- if (
- keep_in_fp32_modules is not None
- and any(
- module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
- )
- and dtype == torch.float16
+ if dtype is not None and torch.is_floating_point(param):
+ if keep_in_fp32_modules is not None and any(
+ module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
):
param = param.to(torch.float32)
- if accepts_dtype:
- set_module_kwargs["dtype"] = torch.float32
+ set_module_kwargs["dtype"] = torch.float32
+ # For quantizers have save weights using torch.float8_e4m3fn
+ elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
+ pass
else:
param = param.to(dtype)
- if accepts_dtype:
- set_module_kwargs["dtype"] = dtype
+ set_module_kwargs["dtype"] = dtype
- # bnb params are flattened.
- if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
- model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
- raise ValueError(
- f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
- )
+ # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
+ # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
+ # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
+ old_param = model
+ splits = param_name.split(".")
+ for split in splits:
+ old_param = getattr(old_param, split)
- if not is_quantized or (
- not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
- ):
- if accepts_dtype:
- set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+ if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+ old_param = None
+
+ if old_param is not None:
+ if dtype is None:
+ param = param.to(old_param.dtype)
+
+ if old_param.is_contiguous():
+ param = param.contiguous()
+
+ param_device = _determine_param_device(param_name, device_map)
+
+ # bnb params are flattened.
+ # gguf quants have a different shape based on the type of quantization applied
+ if empty_state_dict[param_name].shape != param.shape:
+ if (
+ is_quantized
+ and hf_quantizer.pre_quantized
+ and hf_quantizer.check_if_quantized_param(
+ model, param, param_name, state_dict, param_device=param_device
+ )
+ ):
+ hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
else:
- set_module_tensor_to_device(model, param_name, device, value=param)
+ model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+ raise ValueError(
+ f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+ )
+ if param_device == "disk":
+ offload_index = offload_weight(param, param_name, offload_folder, offload_index)
+ elif param_device == "cpu" and state_dict_index is not None:
+ state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
+ elif is_quantized and (
+ hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
+ ):
+ hf_quantizer.create_quantized_param(
+ model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
+ )
else:
- hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+ set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
- return unexpected_keys
+ return offload_index, state_dict_index
-def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
+def _load_state_dict_into_model(
+ model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
+) -> List[str]:
# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
state_dict = state_dict.copy()
@@ -238,15 +314,19 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
- def load(module: torch.nn.Module, prefix: str = ""):
- args = (state_dict, prefix, {}, True, [], [], error_msgs)
+ def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False):
+ local_metadata = {}
+ local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
+ if assign_to_params_buffers and not is_torch_version(">=", "2.1"):
+ logger.info("You need to have torch>=2.1 in order to load the model with assign_to_params_buffers=True")
+ args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
- load(child, prefix + name + ".")
+ load(child, prefix + name + ".", assign_to_params_buffers)
- load(model_to_load)
+ load(model_to_load, assign_to_params_buffers=assign_to_params_buffers)
return error_msgs
@@ -265,6 +345,7 @@ def _fetch_index_file(
revision,
user_agent,
commit_hash,
+ dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
if is_local:
index_file = Path(
@@ -290,43 +371,16 @@ def _fetch_index_file(
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
+ dduf_entries=dduf_entries,
)
- index_file = Path(index_file)
+ if not dduf_entries:
+ index_file = Path(index_file)
except (EntryNotFoundError, EnvironmentError):
index_file = None
return index_file
-# Adapted from
-# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
-def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
- weight_map = sharded_metadata.get("weight_map", None)
- if weight_map is None:
- raise KeyError("'weight_map' key not found in the shard index file.")
-
- # Collect all unique safetensors files from weight_map
- files_to_load = set(weight_map.values())
- is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
- merged_state_dict = {}
-
- # Load tensors from each unique file
- for file_name in files_to_load:
- part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
- if not os.path.exists(part_file_path):
- raise FileNotFoundError(f"Part file {file_name} not found.")
-
- if is_safetensors:
- with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
- for tensor_key in f.keys():
- if tensor_key in weight_map:
- merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
- else:
- merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
-
- return merged_state_dict
-
-
def _fetch_index_file_legacy(
is_local,
pretrained_model_name_or_path,
@@ -341,6 +395,7 @@ def _fetch_index_file_legacy(
revision,
user_agent,
commit_hash,
+ dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
if is_local:
index_file = Path(
@@ -381,6 +436,7 @@ def _fetch_index_file_legacy(
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
+ dduf_entries=dduf_entries,
)
index_file = Path(index_file)
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
@@ -389,3 +445,78 @@ def _fetch_index_file_legacy(
index_file = None
return index_file
+
+
+def _gguf_parse_value(_value, data_type):
+ if not isinstance(data_type, list):
+ data_type = [data_type]
+ if len(data_type) == 1:
+ data_type = data_type[0]
+ array_data_type = None
+ else:
+ if data_type[0] != 9:
+ raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
+ data_type, array_data_type = data_type
+
+ if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
+ _value = int(_value[0])
+ elif data_type in [6, 12]:
+ _value = float(_value[0])
+ elif data_type in [7]:
+ _value = bool(_value[0])
+ elif data_type in [8]:
+ _value = array("B", list(_value)).tobytes().decode()
+ elif data_type in [9]:
+ _value = _gguf_parse_value(_value, array_data_type)
+ return _value
+
+
+def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
+ """
+ Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
+ attributes.
+
+ Args:
+ gguf_checkpoint_path (`str`):
+ The path the to GGUF file to load
+ return_tensors (`bool`, defaults to `True`):
+ Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
+ metadata in memory.
+ """
+
+ if is_gguf_available() and is_torch_available():
+ import gguf
+ from gguf import GGUFReader
+
+ from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
+ else:
+ logger.error(
+ "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
+ "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
+ )
+ raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
+
+ reader = GGUFReader(gguf_checkpoint_path)
+
+ parsed_parameters = {}
+ for tensor in reader.tensors:
+ name = tensor.name
+ quant_type = tensor.tensor_type
+
+ # if the tensor is a torch supported dtype do not use GGUFParameter
+ is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
+ if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
+ _supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
+ raise ValueError(
+ (
+ f"{name} has a quantization type: {str(quant_type)} which is unsupported."
+ "\n\nCurrently the following quantization types are supported: \n\n"
+ f"{_supported_quants_str}"
+ "\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
+ )
+ )
+
+ weights = torch.from_numpy(tensor.data.copy())
+ parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
+
+ return parsed_parameters
diff --git a/src/diffusers/models/modeling_flax_pytorch_utils.py b/src/diffusers/models/modeling_flax_pytorch_utils.py
index 4db537f54b94..d64c48a9601e 100644
--- a/src/diffusers/models/modeling_flax_pytorch_utils.py
+++ b/src/diffusers/models/modeling_flax_pytorch_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py
index 8c35fab0fc16..52f004f6f93f 100644
--- a/src/diffusers/models/modeling_flax_utils.py
+++ b/src/diffusers/models/modeling_flax_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -530,7 +530,7 @@ def save_pretrained(
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
- private = kwargs.pop("private", False)
+ private = kwargs.pop("private", None)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
diff --git a/src/diffusers/models/modeling_pytorch_flax_utils.py b/src/diffusers/models/modeling_pytorch_flax_utils.py
index 55eff0e1ed54..ada55073dd55 100644
--- a/src/diffusers/models/modeling_pytorch_flax_utils.py
+++ b/src/diffusers/models/modeling_pytorch_flax_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 4a486fd4ce40..19ac868cdae0 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,16 +20,21 @@
import json
import os
import re
+import shutil
+import tempfile
from collections import OrderedDict
-from functools import partial, wraps
+from contextlib import ExitStack, contextmanager
+from functools import wraps
from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Type, Union
import safetensors
import torch
-from huggingface_hub import create_repo, split_torch_state_dict_into_shards
+import torch.utils.checkpoint
+from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn
+from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
@@ -48,6 +53,7 @@
is_accelerate_available,
is_bitsandbytes_available,
is_bitsandbytes_version,
+ is_peft_available,
is_torch_version,
logging,
)
@@ -61,16 +67,49 @@
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
- _merge_sharded_checkpoints,
load_model_dict_into_meta,
load_state_dict,
)
+class ContextManagers:
+ """
+ Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
+ in the `fastcore` library.
+ """
+
+ def __init__(self, context_managers: List[ContextManager]):
+ self.context_managers = context_managers
+ self.stack = ExitStack()
+
+ def __enter__(self):
+ for context_manager in self.context_managers:
+ self.stack.enter_context(context_manager)
+
+ def __exit__(self, *args, **kwargs):
+ self.stack.__exit__(*args, **kwargs)
+
+
logger = logging.get_logger(__name__)
_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+TORCH_INIT_FUNCTIONS = {
+ "uniform_": nn.init.uniform_,
+ "normal_": nn.init.normal_,
+ "trunc_normal_": nn.init.trunc_normal_,
+ "constant_": nn.init.constant_,
+ "xavier_uniform_": nn.init.xavier_uniform_,
+ "xavier_normal_": nn.init.xavier_normal_,
+ "kaiming_uniform_": nn.init.kaiming_uniform_,
+ "kaiming_normal_": nn.init.kaiming_normal_,
+ "uniform": nn.init.uniform,
+ "normal": nn.init.normal,
+ "xavier_uniform": nn.init.xavier_uniform,
+ "xavier_normal": nn.init.xavier_normal,
+ "kaiming_uniform": nn.init.kaiming_uniform,
+ "kaiming_normal": nn.init.kaiming_normal,
+}
if is_torch_version(">=", "1.9.0"):
_LOW_CPU_MEM_USAGE_DEFAULT = True
@@ -80,10 +119,22 @@
if is_accelerate_available():
import accelerate
+ from accelerate import dispatch_model
+ from accelerate.utils import load_offloaded_weights, save_offload_index
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
+ from ..hooks.group_offloading import _get_group_onload_device
+
+ try:
+ # Try to get the onload device from the group offloading hook
+ return _get_group_onload_device(parameter)
+ except ValueError:
+ pass
+
try:
+ # If the onload device is not available due to no group offloading hooks, try to get the device
+ # from the first parameter or buffer
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
return next(parameters_and_buffers).device
except StopIteration:
@@ -99,21 +150,102 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
- try:
- return next(parameter.parameters()).dtype
- except StopIteration:
- try:
- return next(parameter.buffers()).dtype
- except StopIteration:
- # For torch.nn.DataParallel compatibility in PyTorch 1.5
+ """
+ Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
+ """
+ # 1. Check if we have attached any dtype modifying hooks (eg. layerwise casting)
+ if isinstance(parameter, nn.Module):
+ for name, submodule in parameter.named_modules():
+ if not hasattr(submodule, "_diffusers_hook"):
+ continue
+ registry = submodule._diffusers_hook
+ hook = registry.get_hook("layerwise_casting")
+ if hook is not None:
+ return hook.compute_dtype
+
+ # 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer
+ last_dtype = None
+
+ for name, param in parameter.named_parameters():
+ last_dtype = param.dtype
+ if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
+ continue
+
+ if param.is_floating_point():
+ return param.dtype
+
+ for buffer in parameter.buffers():
+ last_dtype = buffer.dtype
+ if buffer.is_floating_point():
+ return buffer.dtype
+
+ if last_dtype is not None:
+ # if no floating dtype was found return whatever the first dtype is
+ return last_dtype
+
+ # For nn.DataParallel compatibility in PyTorch > 1.5
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ last_tuple = None
+ for tuple in gen:
+ last_tuple = tuple
+ if tuple[1].is_floating_point():
+ return tuple[1].dtype
+
+ if last_tuple is not None:
+ # fallback to the last dtype
+ return last_tuple[1].dtype
+
+
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+ """
+ Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+ checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
+ parameters.
+
+ """
+ if model_to_load.device.type == "meta":
+ return False
+
+ if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+ return False
+
+ # Some models explicitly do not support param buffer assignment
+ if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+ logger.debug(
+ f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+ )
+ return False
+
+ # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+ first_key = next(iter(model_to_load.state_dict().keys()))
+ if start_prefix + first_key in state_dict:
+ return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
- def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
- tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
- return tuples
+ return False
- gen = parameter._named_members(get_members_fn=find_tensor_attributes)
- first_tuple = next(gen)
- return first_tuple[1].dtype
+
+@contextmanager
+def no_init_weights():
+ """
+ Context manager to globally disable weight initialization to speed up loading large models. To do that, all the
+ torch.nn.init function are all replaced with skip.
+ """
+
+ def _skip_init(*args, **kwargs):
+ pass
+
+ for name, init_func in TORCH_INIT_FUNCTIONS.items():
+ setattr(torch.nn.init, name, _skip_init)
+ try:
+ yield
+ finally:
+ # Restore the original initialization functions
+ for name, init_func in TORCH_INIT_FUNCTIONS.items():
+ setattr(torch.nn.init, name, init_func)
class ModelMixin(torch.nn.Module, PushToHubMixin):
@@ -132,10 +264,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_keys_to_ignore_on_load_unexpected = None
_no_split_modules = None
_keep_in_fp32_modules = None
+ _skip_layerwise_casting_patterns = None
+ _supports_group_offloading = True
def __init__(self):
super().__init__()
+ self._gradient_checkpointing_func = None
+
def __getattr__(self, name: str) -> Any:
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
@@ -161,14 +297,35 @@ def is_gradient_checkpointing(self) -> bool:
"""
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
- def enable_gradient_checkpointing(self) -> None:
+ def enable_gradient_checkpointing(self, gradient_checkpointing_func: Optional[Callable] = None) -> None:
"""
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
*checkpoint activations* in other frameworks).
+
+ Args:
+ gradient_checkpointing_func (`Callable`, *optional*):
+ The function to use for gradient checkpointing. If `None`, the default PyTorch checkpointing function
+ is used (`torch.utils.checkpoint.checkpoint`).
"""
if not self._supports_gradient_checkpointing:
- raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
- self.apply(partial(self._set_gradient_checkpointing, value=True))
+ raise ValueError(
+ f"{self.__class__.__name__} does not support gradient checkpointing. Please make sure to set the boolean attribute "
+ f"`_supports_gradient_checkpointing` to `True` in the class definition."
+ )
+
+ if gradient_checkpointing_func is None:
+
+ def _gradient_checkpointing_func(module, *args):
+ ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ return torch.utils.checkpoint.checkpoint(
+ module.__call__,
+ *args,
+ **ckpt_kwargs,
+ )
+
+ gradient_checkpointing_func = _gradient_checkpointing_func
+
+ self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
def disable_gradient_checkpointing(self) -> None:
"""
@@ -176,7 +333,7 @@ def disable_gradient_checkpointing(self) -> None:
*checkpoint activations* in other frameworks).
"""
if self._supports_gradient_checkpointing:
- self.apply(partial(self._set_gradient_checkpointing, value=False))
+ self._set_gradient_checkpointing(enable=False)
def set_use_npu_flash_attention(self, valid: bool) -> None:
r"""
@@ -208,6 +365,35 @@ def disable_npu_flash_attention(self) -> None:
"""
self.set_use_npu_flash_attention(False)
+ def set_use_xla_flash_attention(
+ self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None, **kwargs
+ ) -> None:
+ # Recursively walk through all the children.
+ # Any children which exposes the set_use_xla_flash_attention method
+ # gets the message
+ def fn_recursive_set_flash_attention(module: torch.nn.Module):
+ if hasattr(module, "set_use_xla_flash_attention"):
+ module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec, **kwargs)
+
+ for child in module.children():
+ fn_recursive_set_flash_attention(child)
+
+ for module in self.children():
+ if isinstance(module, torch.nn.Module):
+ fn_recursive_set_flash_attention(module)
+
+ def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None, **kwargs):
+ r"""
+ Enable the flash attention pallals kernel for torch_xla.
+ """
+ self.set_use_xla_flash_attention(True, partition_spec, **kwargs)
+
+ def disable_xla_flash_attention(self):
+ r"""
+ Disable the flash attention pallals kernel for torch_xla.
+ """
+ self.set_use_xla_flash_attention(False)
+
def set_use_memory_efficient_attention_xformers(
self, valid: bool, attention_op: Optional[Callable] = None
) -> None:
@@ -267,6 +453,150 @@ def disable_xformers_memory_efficient_attention(self) -> None:
"""
self.set_use_memory_efficient_attention_xformers(False)
+ def enable_layerwise_casting(
+ self,
+ storage_dtype: torch.dtype = torch.float8_e4m3fn,
+ compute_dtype: Optional[torch.dtype] = None,
+ skip_modules_pattern: Optional[Tuple[str, ...]] = None,
+ skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
+ non_blocking: bool = False,
+ ) -> None:
+ r"""
+ Activates layerwise casting for the current model.
+
+ Layerwise casting is a technique that casts the model weights to a lower precision dtype for storage but
+ upcasts them on-the-fly to a higher precision dtype for computation. This process can significantly reduce the
+ memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations
+ are negligible, mostly stemming from weight casting in normalization and modulation layers.
+
+ By default, most models in diffusers set the `_skip_layerwise_casting_patterns` attribute to ignore patch
+ embedding, positional embedding and normalization layers. This is because these layers are most likely
+ precision-critical for quality. If you wish to change this behavior, you can set the
+ `_skip_layerwise_casting_patterns` attribute to `None`, or call
+ [`~hooks.layerwise_casting.apply_layerwise_casting`] with custom arguments.
+
+ Example:
+ Using [`~models.ModelMixin.enable_layerwise_casting`]:
+
+ ```python
+ >>> from diffusers import CogVideoXTransformer3DModel
+
+ >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+ ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+ ... )
+
+ >>> # Enable layerwise casting via the model, which ignores certain modules by default
+ >>> transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+ ```
+
+ Args:
+ storage_dtype (`torch.dtype`):
+ The dtype to which the model should be cast for storage.
+ compute_dtype (`torch.dtype`):
+ The dtype to which the model weights should be cast during the forward pass.
+ skip_modules_pattern (`Tuple[str, ...]`, *optional*):
+ A list of patterns to match the names of the modules to skip during the layerwise casting process. If
+ set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT
+ layers.
+ skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
+ A list of module classes to skip during the layerwise casting process.
+ non_blocking (`bool`, *optional*, defaults to `False`):
+ If `True`, the weight casting operations are non-blocking.
+ """
+ from ..hooks import apply_layerwise_casting
+
+ user_provided_patterns = True
+ if skip_modules_pattern is None:
+ from ..hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
+
+ skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
+ user_provided_patterns = False
+ if self._keep_in_fp32_modules is not None:
+ skip_modules_pattern += tuple(self._keep_in_fp32_modules)
+ if self._skip_layerwise_casting_patterns is not None:
+ skip_modules_pattern += tuple(self._skip_layerwise_casting_patterns)
+ skip_modules_pattern = tuple(set(skip_modules_pattern))
+
+ if is_peft_available() and not user_provided_patterns:
+ # By default, we want to skip all peft layers because they have a very low memory footprint.
+ # If users want to apply layerwise casting on peft layers as well, they can utilize the
+ # `~diffusers.hooks.layerwise_casting.apply_layerwise_casting` function which provides
+ # them with more flexibility and control.
+
+ from peft.tuners.loha.layer import LoHaLayer
+ from peft.tuners.lokr.layer import LoKrLayer
+ from peft.tuners.lora.layer import LoraLayer
+
+ for layer in (LoHaLayer, LoKrLayer, LoraLayer):
+ skip_modules_pattern += tuple(layer.adapter_layer_names)
+
+ if compute_dtype is None:
+ logger.info("`compute_dtype` not provided when enabling layerwise casting. Using dtype of the model.")
+ compute_dtype = self.dtype
+
+ apply_layerwise_casting(
+ self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
+ )
+
+ def enable_group_offload(
+ self,
+ onload_device: torch.device,
+ offload_device: torch.device = torch.device("cpu"),
+ offload_type: str = "block_level",
+ num_blocks_per_group: Optional[int] = None,
+ non_blocking: bool = False,
+ use_stream: bool = False,
+ low_cpu_mem_usage=False,
+ ) -> None:
+ r"""
+ Activates group offloading for the current model.
+
+ See [`~hooks.group_offloading.apply_group_offloading`] for more information.
+
+ Example:
+
+ ```python
+ >>> from diffusers import CogVideoXTransformer3DModel
+
+ >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+ ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+ ... )
+
+ >>> transformer.enable_group_offload(
+ ... onload_device=torch.device("cuda"),
+ ... offload_device=torch.device("cpu"),
+ ... offload_type="leaf_level",
+ ... use_stream=True,
+ ... )
+ ```
+ """
+ from ..hooks import apply_group_offloading
+
+ if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
+ msg = (
+ "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
+ "forward pass is executed with tiling enabled. Please make sure to either:\n"
+ "1. Run a forward pass with small input shapes.\n"
+ "2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)."
+ )
+ logger.warning(msg)
+ if not self._supports_group_offloading:
+ raise ValueError(
+ f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute "
+ f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
+ f"open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+ apply_group_offloading(
+ self,
+ onload_device,
+ offload_device,
+ offload_type,
+ num_blocks_per_group,
+ non_blocking,
+ use_stream,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -338,7 +668,7 @@ def save_pretrained(
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
- private = kwargs.pop("private", False)
+ private = kwargs.pop("private", None)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -379,7 +709,7 @@ def save_pretrained(
os.remove(full_filename)
for filename, tensors in state_dict_split.filename_to_tensors.items():
- shard = {tensor: state_dict[tensor] for tensor in tensors}
+ shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
@@ -436,7 +766,7 @@ def dequantize(self):
@classmethod
@validate_hf_hub_args
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self:
r"""
Instantiate a pretrained PyTorch model from a pretrained model configuration.
@@ -512,6 +842,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
+ disable_mmap ('bool', *optional*, defaults to 'False'):
+ Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+ is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -552,11 +885,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
- offload_state_dict = kwargs.pop("offload_state_dict", False)
+ offload_state_dict = kwargs.pop("offload_state_dict", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
quantization_config = kwargs.pop("quantization_config", None)
+ dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
+ disable_mmap = kwargs.pop("disable_mmap", False)
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ torch_dtype = torch.float32
+ logger.warning(
+ f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+ )
allow_pickle = False
if use_safetensors is None:
@@ -627,14 +968,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
- # Load config if we don't provide a configuration
- config_path = pretrained_model_name_or_path
-
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}
+ unused_kwargs = {}
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
# load config
config, unused_kwargs, commit_hash = cls.load_config(
@@ -649,6 +991,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
+ dduf_entries=dduf_entries,
**kwargs,
)
# no in-place modification of the original config.
@@ -671,12 +1014,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
hf_quantizer = None
if hf_quantizer is not None:
- if device_map is not None:
- raise NotImplementedError(
- "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
- )
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+ device_map = hf_quantizer.update_device_map(device_map)
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
@@ -689,9 +1029,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
# Check if `_keep_in_fp32_modules` is not None
- use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
- (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+ use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None and (
+ hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
)
+
if use_keep_in_fp32_modules:
keep_in_fp32_modules = cls._keep_in_fp32_modules
if not isinstance(keep_in_fp32_modules, list):
@@ -704,10 +1045,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
else:
keep_in_fp32_modules = []
- #######################################
- # Determine if we're loading from a directory of sharded checkpoints.
is_sharded = False
+ resolved_model_file = None
+
+ # Determine if we're loading from a directory of sharded checkpoints.
+ sharded_metadata = None
index_file = None
is_local = os.path.isdir(pretrained_model_name_or_path)
index_file_kwargs = {
@@ -724,22 +1067,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"revision": revision,
"user_agent": user_agent,
"commit_hash": commit_hash,
+ "dduf_entries": dduf_entries,
}
index_file = _fetch_index_file(**index_file_kwargs)
# In case the index file was not found we still have to consider the legacy format.
# this becomes applicable when the variant is not None.
if variant is not None and (index_file is None or not os.path.exists(index_file)):
index_file = _fetch_index_file_legacy(**index_file_kwargs)
- if index_file is not None and index_file.is_file():
+ if index_file is not None and (dduf_entries or index_file.is_file()):
is_sharded = True
if is_sharded and from_flax:
raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
# load model
- model_file = None
if from_flax:
- model_file = _get_model_file(
+ resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=FLAX_WEIGHTS_NAME,
cache_dir=cache_dir,
@@ -757,10 +1100,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Convert the weights
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
- model = load_flax_checkpoint_in_pytorch_model(model, model_file)
+ model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
else:
+ # in the case it is sharded, we have already the index
if is_sharded:
- sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
+ resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_file,
cache_dir=cache_dir,
@@ -770,15 +1114,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
user_agent=user_agent,
revision=revision,
subfolder=subfolder or "",
+ dduf_entries=dduf_entries,
)
- if hf_quantizer is not None:
- model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
- logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
- is_sharded = False
-
- elif use_safetensors and not is_sharded:
+ elif use_safetensors:
try:
- model_file = _get_model_file(
+ resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
@@ -790,6 +1130,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
+ dduf_entries=dduf_entries,
)
except IOError as e:
@@ -800,8 +1141,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)
- if model_file is None and not is_sharded:
- model_file = _get_model_file(
+ if resolved_model_file is None and not is_sharded:
+ resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(WEIGHTS_NAME, variant),
cache_dir=cache_dir,
@@ -813,159 +1154,107 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
+ dduf_entries=dduf_entries,
)
- if low_cpu_mem_usage:
- # Instantiate model with empty weights
- with accelerate.init_empty_weights():
- model = cls.from_config(config, **unused_kwargs)
+ if not isinstance(resolved_model_file, list):
+ resolved_model_file = [resolved_model_file]
- if hf_quantizer is not None:
- hf_quantizer.preprocess_model(
- model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
- )
+ # set dtype to instantiate the model under:
+ # 1. If torch_dtype is not None, we use that dtype
+ # 2. If torch_dtype is float8, we don't use _set_default_torch_dtype and we downcast after loading the model
+ dtype_orig = None
+ if torch_dtype is not None and not torch_dtype == getattr(torch, "float8_e4m3fn", None):
+ if not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ dtype_orig = cls._set_default_torch_dtype(torch_dtype)
- # if device_map is None, load the state dict and move the params from meta device to the cpu
- if device_map is None and not is_sharded:
- # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
- # It would error out during the `validate_environment()` call above in the absence of cuda.
- is_quant_method_bnb = (
- getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
- )
- if hf_quantizer is None:
- param_device = "cpu"
- # TODO (sayakpaul, SunMarc): remove this after model loading refactor
- elif is_quant_method_bnb:
- param_device = torch.cuda.current_device()
- state_dict = load_state_dict(model_file, variant=variant)
- model._convert_deprecated_attention_blocks(state_dict)
-
- # move the params from meta device to cpu
- missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
- if hf_quantizer is not None:
- missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
- if len(missing_keys) > 0:
- raise ValueError(
- f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
- f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
- " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
- " those weights or else make sure your checkpoint file is correct."
- )
+ init_contexts = [no_init_weights()]
- unexpected_keys = load_model_dict_into_meta(
- model,
- state_dict,
- device=param_device,
- dtype=torch_dtype,
- model_name_or_path=pretrained_model_name_or_path,
- hf_quantizer=hf_quantizer,
- keep_in_fp32_modules=keep_in_fp32_modules,
- )
+ if low_cpu_mem_usage:
+ init_contexts.append(accelerate.init_empty_weights())
- if cls._keys_to_ignore_on_load_unexpected is not None:
- for pat in cls._keys_to_ignore_on_load_unexpected:
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+ with ContextManagers(init_contexts):
+ model = cls.from_config(config, **unused_kwargs)
- if len(unexpected_keys) > 0:
- logger.warning(
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
- )
+ if dtype_orig is not None:
+ torch.set_default_dtype(dtype_orig)
- else: # else let accelerate handle loading and dispatching.
- # Load weights and dispatch according to the device_map
- # by default the device_map is None and the weights are loaded on the CPU
- force_hook = True
- device_map = _determine_device_map(
- model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
- )
- if device_map is None and is_sharded:
- # we load the parameters on the cpu
- device_map = {"": "cpu"}
- force_hook = False
- try:
- accelerate.load_checkpoint_and_dispatch(
- model,
- model_file if not is_sharded else index_file,
- device_map,
- max_memory=max_memory,
- offload_folder=offload_folder,
- offload_state_dict=offload_state_dict,
- dtype=torch_dtype,
- force_hooks=force_hook,
- strict=True,
- )
- except AttributeError as e:
- # When using accelerate loading, we do not have the ability to load the state
- # dict and rename the weight names manually. Additionally, accelerate skips
- # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
- # (which look like they should be private variables?), so we can't use the standard hooks
- # to rename parameters on load. We need to mimic the original weight names so the correct
- # attributes are available. After we have loaded the weights, we convert the deprecated
- # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
- # the weights so we don't have to do this again.
-
- if "'Attention' object has no attribute" in str(e):
- logger.warning(
- f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
- " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
- " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
- " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
- " please also re-upload it or open a PR on the original repository."
- )
- model._temp_convert_self_to_deprecated_attention_blocks()
- accelerate.load_checkpoint_and_dispatch(
- model,
- model_file if not is_sharded else index_file,
- device_map,
- max_memory=max_memory,
- offload_folder=offload_folder,
- offload_state_dict=offload_state_dict,
- dtype=torch_dtype,
- force_hooks=force_hook,
- strict=True,
- )
- model._undo_temp_convert_self_to_deprecated_attention_blocks()
- else:
- raise e
-
- loading_info = {
- "missing_keys": [],
- "unexpected_keys": [],
- "mismatched_keys": [],
- "error_msgs": [],
- }
- else:
- model = cls.from_config(config, **unused_kwargs)
+ state_dict = None
+ if not is_sharded:
+ # Time to load the checkpoint
+ state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
+ # We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
+ model._fix_state_dict_keys_on_load(state_dict)
- state_dict = load_state_dict(model_file, variant=variant)
- model._convert_deprecated_attention_blocks(state_dict)
+ if is_sharded:
+ loaded_keys = sharded_metadata["all_checkpoint_keys"]
+ else:
+ loaded_keys = list(state_dict.keys())
- model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
- model,
- state_dict,
- model_file,
- pretrained_model_name_or_path,
- ignore_mismatched_sizes=ignore_mismatched_sizes,
- )
+ if hf_quantizer is not None:
+ hf_quantizer.preprocess_model(
+ model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+ )
- loading_info = {
- "missing_keys": missing_keys,
- "unexpected_keys": unexpected_keys,
- "mismatched_keys": mismatched_keys,
- "error_msgs": error_msgs,
- }
+ # Now that the model is loaded, we can determine the device_map
+ device_map = _determine_device_map(
+ model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+ )
+ if hf_quantizer is not None:
+ hf_quantizer.validate_environment(device_map=device_map)
+
+ (
+ model,
+ missing_keys,
+ unexpected_keys,
+ mismatched_keys,
+ offload_index,
+ error_msgs,
+ ) = cls._load_pretrained_model(
+ model,
+ state_dict,
+ resolved_model_file,
+ pretrained_model_name_or_path,
+ loaded_keys,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ device_map=device_map,
+ offload_folder=offload_folder,
+ offload_state_dict=offload_state_dict,
+ dtype=torch_dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ dduf_entries=dduf_entries,
+ )
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+
+ # Dispatch model with hooks on all devices if necessary
+ if device_map is not None:
+ device_map_kwargs = {
+ "device_map": device_map,
+ "offload_dir": offload_folder,
+ "offload_index": offload_index,
+ }
+ dispatch_model(model, **device_map_kwargs)
if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
model.hf_quantizer = hf_quantizer
- if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
- raise ValueError(
- f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
- )
- # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
- # completely lose the effectivity of `use_keep_in_fp32_modules`.
- elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
+ if (
+ torch_dtype is not None
+ and torch_dtype == getattr(torch, "float8_e4m3fn", None)
+ and hf_quantizer is None
+ and not use_keep_in_fp32_modules
+ ):
model = model.to(torch_dtype)
if hf_quantizer is not None:
@@ -977,6 +1266,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
+
if output_loading_info:
return model, loading_info
@@ -985,6 +1275,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Adapted from `transformers`.
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
+ from ..hooks.group_offloading import _is_group_offload_enabled
+
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if getattr(self, "is_loaded_in_8bit", False):
@@ -997,27 +1289,48 @@ def cuda(self, *args, **kwargs):
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
)
+
+ # Checks if group offloading is enabled
+ if _is_group_offload_enabled(self):
+ logger.warning(
+ f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported."
+ )
+ return self
+
return super().cuda(*args, **kwargs)
# Adapted from `transformers`.
@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs):
+ from ..hooks.group_offloading import _is_group_offload_enabled
+
+ device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
dtype_present_in_args = "dtype" in kwargs
+ # Try converting arguments to torch.device in case they are passed as strings
+ for arg in args:
+ if not isinstance(arg, str):
+ continue
+ try:
+ torch.device(arg)
+ device_arg_or_kwarg_present = True
+ except RuntimeError:
+ pass
+
if not dtype_present_in_args:
for arg in args:
if isinstance(arg, torch.dtype):
dtype_present_in_args = True
break
- # Checks if the model has been loaded in 4-bit or 8-bit with BNB
- if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+ if getattr(self, "is_quantized", False):
if dtype_present_in_args:
raise ValueError(
- "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
- " desired `dtype` by passing the correct `torch_dtype` argument."
+ "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
+ "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
)
+ if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if getattr(self, "is_loaded_in_8bit", False):
raise ValueError(
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
@@ -1028,6 +1341,13 @@ def to(self, *args, **kwargs):
"Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
)
+
+ if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
+ logger.warning(
+ f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
+ )
+ return self
+
return super().to(*args, **kwargs)
# Taken from `transformers`.
@@ -1057,54 +1377,127 @@ def _load_pretrained_model(
cls,
model,
state_dict: OrderedDict,
- resolved_archive_file,
+ resolved_model_file: List[str],
pretrained_model_name_or_path: Union[str, os.PathLike],
+ loaded_keys: List[str],
ignore_mismatched_sizes: bool = False,
+ assign_to_params_buffers: bool = False,
+ hf_quantizer: Optional[DiffusersQuantizer] = None,
+ low_cpu_mem_usage: bool = True,
+ dtype: Optional[Union[str, torch.dtype]] = None,
+ keep_in_fp32_modules: Optional[List[str]] = None,
+ device_map: Dict[str, Union[int, str, torch.device]] = None,
+ offload_state_dict: Optional[bool] = None,
+ offload_folder: Optional[Union[str, os.PathLike]] = None,
+ dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
- # Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
- loaded_keys = list(state_dict.keys())
-
expected_keys = list(model_state_dict.keys())
-
- original_loaded_keys = loaded_keys
-
missing_keys = list(set(expected_keys) - set(loaded_keys))
+ if hf_quantizer is not None:
+ missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+ # Some models may have keys that are not in the state by design, removing them before needlessly warning
+ # the user.
+ if cls._keys_to_ignore_on_load_unexpected is not None:
+ for pat in cls._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
- # Make sure we are able to load base models as well as derived models (with heads)
- model_to_load = model
+ mismatched_keys = []
- def _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
- ):
- mismatched_keys = []
- if ignore_mismatched_sizes:
- for checkpoint_key in loaded_keys:
- model_key = checkpoint_key
-
- if (
- model_key in model_state_dict
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
- ):
- mismatched_keys.append(
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
- )
- del state_dict[checkpoint_key]
- return mismatched_keys
+ assign_to_params_buffers = None
+ error_msgs = []
+
+ # Deal with offload
+ if device_map is not None and "disk" in device_map.values():
+ if offload_folder is None:
+ raise ValueError(
+ "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
+ " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
+ " offers the weights in this format."
+ )
+ if offload_folder is not None:
+ os.makedirs(offload_folder, exist_ok=True)
+ if offload_state_dict is None:
+ offload_state_dict = True
+
+ offload_index = {} if device_map is not None and "disk" in device_map.values() else None
+ if offload_state_dict:
+ state_dict_folder = tempfile.mkdtemp()
+ state_dict_index = {}
+ else:
+ state_dict_folder = None
+ state_dict_index = None
if state_dict is not None:
- # Whole checkpoint
- mismatched_keys = _find_mismatched_keys(
+ # load_state_dict will manage the case where we pass a dict instead of a file
+ # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
+ resolved_model_file = [state_dict]
+
+ if len(resolved_model_file) > 1:
+ resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+
+ for shard_file in resolved_model_file:
+ state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+
+ def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+ # If the checkpoint is sharded, we may not have the key here.
+ if checkpoint_key not in state_dict:
+ continue
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
- original_loaded_keys,
+ loaded_keys,
ignore_mismatched_sizes,
)
- error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if low_cpu_mem_usage:
+ offload_index, state_dict_index = load_model_dict_into_meta(
+ model,
+ state_dict,
+ device_map=device_map,
+ dtype=dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ unexpected_keys=unexpected_keys,
+ offload_folder=offload_folder,
+ offload_index=offload_index,
+ state_dict_index=state_dict_index,
+ state_dict_folder=state_dict_folder,
+ )
+ else:
+ if assign_to_params_buffers is None:
+ assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+ error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+
+ if offload_index is not None and len(offload_index) > 0:
+ save_offload_index(offload_index, offload_folder)
+ offload_index = None
+
+ if offload_state_dict:
+ load_offloaded_weights(model, state_dict_index, state_dict_folder)
+ shutil.rmtree(state_dict_folder)
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
@@ -1116,17 +1509,11 @@ def _find_mismatched_keys(
if len(unexpected_keys) > 0:
logger.warning(
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
- f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
- f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
- " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
- " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
- f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
- " identical (initializing a BertForSequenceClassification model from a"
- " BertForSequenceClassification model)."
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
@@ -1154,7 +1541,7 @@ def _find_mismatched_keys(
" able to use it for predictions and inference."
)
- return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+ return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
@classmethod
def _get_signature_keys(cls, obj):
@@ -1168,7 +1555,7 @@ def _get_signature_keys(cls, obj):
# Adapted from `transformers` modeling_utils.py
def _get_no_split_modules(self, device_map: str):
"""
- Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
+ Get the modules of the model that should not be split when using device_map. We iterate through the modules to
get the underlying `_no_split_modules`.
Args:
@@ -1195,6 +1582,33 @@ def _get_no_split_modules(self, device_map: str):
modules_to_check += list(module.children())
return list(_no_split_modules)
+ @classmethod
+ def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
+ """
+ Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
+ under specific dtype.
+
+ Args:
+ dtype (`torch.dtype`):
+ a floating dtype to set to.
+
+ Returns:
+ `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
+ modified. If it wasn't, returns `None`.
+
+ Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
+ `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
+ """
+ if not dtype.is_floating_point:
+ raise ValueError(
+ f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
+ )
+
+ logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
+ dtype_orig = torch.get_default_dtype()
+ torch.set_default_dtype(dtype)
+ return dtype_orig
+
@property
def device(self) -> torch.device:
"""
@@ -1292,7 +1706,31 @@ def get_memory_footprint(self, return_buffers=True):
mem = mem + mem_bufs
return mem
- def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
+ def _set_gradient_checkpointing(
+ self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
+ ) -> None:
+ is_gradient_checkpointing_set = False
+
+ for name, module in self.named_modules():
+ if hasattr(module, "gradient_checkpointing"):
+ logger.debug(f"Setting `gradient_checkpointing={enable}` for '{name}'")
+ module._gradient_checkpointing_func = gradient_checkpointing_func
+ module.gradient_checkpointing = enable
+ is_gradient_checkpointing_set = True
+
+ if not is_gradient_checkpointing_set:
+ raise ValueError(
+ f"The module {self.__class__.__name__} does not support gradient checkpointing. Please make sure to "
+ f"use a module that supports gradient checkpointing by creating a boolean attribute `gradient_checkpointing`."
+ )
+
+ def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None:
+ """
+ This function fix the state dict of the model to take into account some changes that were made in the model
+ architecture:
+ - deprecated attention blocks (happened before we introduced sharded checkpoint,
+ so this is why we apply this method only when loading non sharded checkpoints for now)
+ """
deprecated_attention_block_paths = []
def recursive_find_attn_block(name, module):
@@ -1335,56 +1773,7 @@ def recursive_find_attn_block(name, module):
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
if f"{path}.proj_attn.bias" in state_dict:
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
-
- def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
- deprecated_attention_block_modules = []
-
- def recursive_find_attn_block(module):
- if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
- deprecated_attention_block_modules.append(module)
-
- for sub_module in module.children():
- recursive_find_attn_block(sub_module)
-
- recursive_find_attn_block(self)
-
- for module in deprecated_attention_block_modules:
- module.query = module.to_q
- module.key = module.to_k
- module.value = module.to_v
- module.proj_attn = module.to_out[0]
-
- # We don't _have_ to delete the old attributes, but it's helpful to ensure
- # that _all_ the weights are loaded into the new attributes and we're not
- # making an incorrect assumption that this model should be converted when
- # it really shouldn't be.
- del module.to_q
- del module.to_k
- del module.to_v
- del module.to_out
-
- def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
- deprecated_attention_block_modules = []
-
- def recursive_find_attn_block(module) -> None:
- if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
- deprecated_attention_block_modules.append(module)
-
- for sub_module in module.children():
- recursive_find_attn_block(sub_module)
-
- recursive_find_attn_block(self)
-
- for module in deprecated_attention_block_modules:
- module.to_q = module.query
- module.to_k = module.key
- module.to_v = module.value
- module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
-
- del module.query
- del module.key
- del module.value
- del module.proj_attn
+ return state_dict
class LegacyModelMixin(ModelMixin):
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 029c147fcbac..962ce435bdb7 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -20,12 +20,9 @@
import torch.nn as nn
import torch.nn.functional as F
-from ..utils import is_torch_version
+from ..utils import is_torch_npu_available, is_torch_version
from .activations import get_activation
-from .embeddings import (
- CombinedTimestepLabelEmbeddings,
- PixArtAlphaCombinedTimestepSizeEmbeddings,
-)
+from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
class AdaLayerNorm(nn.Module):
@@ -74,7 +71,7 @@ def forward(
if self.chunk_dim == 1:
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
- # other if-branch. This branch is specific to CogVideoX for now.
+ # other if-branch. This branch is specific to CogVideoX and OmniGen for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
@@ -222,14 +219,13 @@ def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine:
4 * embedding_dim,
bias=True,
)
- self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps)
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
@@ -266,6 +262,7 @@ def forward(
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.
+ added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
@@ -309,6 +306,20 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
class AdaLayerNormContinuous(nn.Module):
+ r"""
+ Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+ Args:
+ embedding_dim (`int`): Embedding dimension to use during projection.
+ conditioning_embedding_dim (`int`): Dimension of the input condition.
+ elementwise_affine (`bool`, defaults to `True`):
+ Boolean flag to denote if affine transformation should be applied.
+ eps (`float`, defaults to 1e-5): Epsilon factor.
+ bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
+ norm_type (`str`, defaults to `"layer_norm"`):
+ Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+ """
+
def __init__(
self,
embedding_dim: int,
@@ -358,20 +369,21 @@ def __init__(
out_dim: Optional[int] = None,
):
super().__init__()
+
# AdaLN
self.silu = nn.SiLU()
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
- # linear_2
+
+ self.linear_2 = None
if out_dim is not None:
- self.linear_2 = nn.Linear(
- embedding_dim,
- out_dim,
- bias=bias,
- )
+ self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
def forward(
self,
@@ -464,6 +476,17 @@ def forward(
# Has optional bias parameter compared to torch layer norm
# TODO: replace with torch layernorm once min required torch version >= 2.1
class LayerNorm(nn.Module):
+ r"""
+ LayerNorm with the bias parameter.
+
+ Args:
+ dim (`int`): Dimensionality to use for the parameters.
+ eps (`float`, defaults to 1e-5): Epsilon factor.
+ elementwise_affine (`bool`, defaults to `True`):
+ Boolean flag to denote if affine transformation should be applied.
+ bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
+ """
+
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
super().__init__()
@@ -486,6 +509,68 @@ def forward(self, input):
class RMSNorm(nn.Module):
+ r"""
+ RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
+
+ Args:
+ dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
+ eps (`float`): Small value to use when calculating the reciprocal of the square-root.
+ elementwise_affine (`bool`, defaults to `True`):
+ Boolean flag to denote if affine transformation should be applied.
+ bias (`bool`, defaults to False): If also training the `bias` param.
+ """
+
+ def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
+ super().__init__()
+
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+
+ if isinstance(dim, numbers.Integral):
+ dim = (dim,)
+
+ self.dim = torch.Size(dim)
+
+ self.weight = None
+ self.bias = None
+
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim))
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(dim))
+
+ def forward(self, hidden_states):
+ if is_torch_npu_available():
+ import torch_npu
+
+ if self.weight is not None:
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+ hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
+ if self.bias is not None:
+ hidden_states = hidden_states + self.bias
+ else:
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+ if self.weight is not None:
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+ hidden_states = hidden_states * self.weight
+ if self.bias is not None:
+ hidden_states = hidden_states + self.bias
+ else:
+ hidden_states = hidden_states.to(input_dtype)
+
+ return hidden_states
+
+
+# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported
+# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013
+class MochiRMSNorm(nn.Module):
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
super().__init__()
@@ -507,17 +592,20 @@ def forward(self, hidden_states):
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if self.weight is not None:
- # convert into half-precision if necessary
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
- hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states * self.weight
- else:
- hidden_states = hidden_states.to(input_dtype)
+ hidden_states = hidden_states.to(input_dtype)
return hidden_states
class GlobalResponseNorm(nn.Module):
+ r"""
+ Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808).
+
+ Args:
+ dim (`int`): Number of dimensions to use for the `gamma` and `beta`.
+ """
+
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
def __init__(self, dim):
super().__init__()
@@ -528,3 +616,33 @@ def forward(self, x):
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * nx) + self.beta + x
+
+
+class LpNorm(nn.Module):
+ def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12):
+ super().__init__()
+
+ self.p = p
+ self.dim = dim
+ self.eps = eps
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
+
+
+def get_normalization(
+ norm_type: str = "batch_norm",
+ num_features: Optional[int] = None,
+ eps: float = 1e-5,
+ elementwise_affine: bool = True,
+ bias: bool = True,
+) -> nn.Module:
+ if norm_type == "rms_norm":
+ norm = RMSNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+ elif norm_type == "layer_norm":
+ norm = nn.LayerNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+ elif norm_type == "batch_norm":
+ norm = nn.BatchNorm2d(num_features, eps=eps, affine=elementwise_affine)
+ else:
+ raise ValueError(f"{norm_type=} is not supported.")
+ return norm
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 00b55cd9c9d6..260b4b8929b0 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -366,7 +366,7 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwarg
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
- input_tensor = self.conv_shortcut(input_tensor)
+ input_tensor = self.conv_shortcut(input_tensor.contiguous())
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
old mode 100644
new mode 100755
index 58787c079ea8..5392935da02b
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -4,6 +4,7 @@
if is_torch_available():
from .auraflow_transformer_2d import AuraFlowTransformer2DModel
from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
+ from .consisid_transformer_3d import ConsisIDTransformer3DModel
from .dit_transformer_2d import DiTTransformer2DModel
from .dual_transformer_2d import DualTransformer2DModel
from .hunyuan_transformer_2d import HunyuanDiT2DModel
@@ -11,10 +12,20 @@
from .lumina_nextdit2d import LuminaNextDiT2DModel
from .pixart_transformer_2d import PixArtTransformer2DModel
from .prior_transformer import PriorTransformer
+ from .sana_transformer import SanaTransformer2DModel
from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
+ from .transformer_allegro import AllegroTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
+ from .transformer_cogview4 import CogView4Transformer2DModel
+ from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
+ from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
+ from .transformer_ltx import LTXVideoTransformer3DModel
+ from .transformer_lumina2 import Lumina2Transformer2DModel
+ from .transformer_mochi import MochiTransformer3DModel
+ from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
+ from .transformer_wan import WanTransformer3DModel
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index ad64df0c0790..4938ed23c506 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -13,14 +13,15 @@
# limitations under the License.
-from typing import Any, Dict, Union
+from typing import Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...loaders import FromOriginalModelMixin
+from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
Attention,
@@ -253,7 +254,7 @@ def forward(
return encoder_hidden_states, hidden_states
-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
@@ -275,6 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
"""
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_supports_gradient_checkpointing = True
@register_to_config
@@ -442,10 +444,6 @@ def unfuse_qkv_projections(self):
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -466,24 +464,12 @@ def forward(
# MMDiT blocks.
for index_block, block in enumerate(self.joint_transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
encoder_hidden_states,
temb,
- **ckpt_kwargs,
)
else:
@@ -497,23 +483,11 @@ def custom_forward(*inputs):
combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- combined_hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ combined_hidden_states = self._gradient_checkpointing_func(
+ block,
combined_hidden_states,
temb,
- **ckpt_kwargs,
)
else:
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index 821da6d032d5..6b4f38dc04a1 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -20,10 +20,11 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -120,8 +121,10 @@ def forward(
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
+ attention_kwargs = attention_kwargs or {}
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +136,7 @@ def forward(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
+ **attention_kwargs,
)
hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -153,7 +157,7 @@ def forward(
return hidden_states, encoder_hidden_states
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
@@ -170,6 +174,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
Whether to flip the sin to cos in the time embedding.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
+ ofs_embed_dim (`int`, defaults to `512`):
+ Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
num_layers (`int`, defaults to `30`):
@@ -177,7 +183,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
- Whether or not to use bias in the attention projection layers.
+ Whether to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
@@ -198,7 +204,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, defaults to `True`):
- Whether or not to use elementwise affine in normalization layers.
+ Whether to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
spatial_interpolation_scale (`float`, defaults to `1.875`):
@@ -207,7 +213,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
"""
+ _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
_supports_gradient_checkpointing = True
+ _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
@register_to_config
def __init__(
@@ -219,6 +227,7 @@ def __init__(
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
+ ofs_embed_dim: Optional[int] = None,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
@@ -227,6 +236,7 @@ def __init__(
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
+ patch_size_t: Optional[int] = None,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
@@ -237,6 +247,7 @@ def __init__(
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
+ patch_bias: bool = True,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@@ -251,10 +262,11 @@ def __init__(
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
+ patch_size_t=patch_size_t,
in_channels=in_channels,
embed_dim=inner_dim,
text_embed_dim=text_embed_dim,
- bias=True,
+ bias=patch_bias,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
@@ -267,10 +279,19 @@ def __init__(
)
self.embedding_dropout = nn.Dropout(dropout)
- # 2. Time embeddings
+ # 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have)
+
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+ self.ofs_proj = None
+ self.ofs_embedding = None
+ if ofs_embed_dim:
+ self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
+ self.ofs_embedding = TimestepEmbedding(
+ ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
+ ) # same as time embeddings, for ofs
+
# 3. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
@@ -298,12 +319,17 @@ def __init__(
norm_eps=norm_eps,
chunk_dim=1,
)
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
- self.gradient_checkpointing = False
+ if patch_size_t is None:
+ # For CogVideox 1.0
+ output_dim = patch_size * patch_size * out_channels
+ else:
+ # For CogVideoX 1.5
+ output_dim = patch_size * patch_size * patch_size_t * out_channels
+
+ self.proj_out = nn.Linear(inner_dim, output_dim)
- def _set_gradient_checkpointing(self, module, value=False):
- self.gradient_checkpointing = value
+ self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
@@ -411,6 +437,7 @@ def forward(
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
+ ofs: Optional[Union[int, float, torch.LongTensor]] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
@@ -442,6 +469,12 @@ def forward(
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
+ if self.ofs_embedding is not None:
+ ofs_emb = self.ofs_proj(ofs)
+ ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
+ ofs_emb = self.ofs_embedding(ofs_emb)
+ emb = emb + ofs_emb
+
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
@@ -452,22 +485,14 @@ def forward(
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
- **ckpt_kwargs,
+ attention_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
@@ -475,28 +500,27 @@ def custom_forward(*inputs):
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
)
- if not self.config.use_rotary_positional_embeddings:
- # CogVideoX-2B
- hidden_states = self.norm_final(hidden_states)
- else:
- # CogVideoX-5B
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- hidden_states = self.norm_final(hidden_states)
- hidden_states = hidden_states[:, text_seq_length:]
+ hidden_states = self.norm_final(hidden_states)
# 4. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 5. Unpatchify
- # Note: we use `-1` instead of `channels`:
- # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
- # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
p = self.config.patch_size
- output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
- output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+ p_t = self.config.patch_size_t
+
+ if p_t is None:
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+ else:
+ output = hidden_states.reshape(
+ batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+ )
+ output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py
new file mode 100644
index 000000000000..f312553e4c05
--- /dev/null
+++ b/src/diffusers/models/transformers/consisid_transformer_3d.py
@@ -0,0 +1,789 @@
+# Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import Attention, FeedForward
+from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, dim: int, dim_head: int = 64, heads: int = 8, kv_dim: Optional[int] = None):
+ super().__init__()
+
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, image_embeds: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
+ # Apply normalization
+ image_embeds = self.norm1(image_embeds)
+ latents = self.norm2(latents)
+
+ batch_size, seq_len, _ = latents.shape # Get batch size and sequence length
+
+ # Compute query, key, and value matrices
+ query = self.to_q(latents)
+ kv_input = torch.cat((image_embeds, latents), dim=-2)
+ key, value = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ # Reshape the tensors for multi-head attention
+ query = query.reshape(query.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
+ key = key.reshape(key.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
+ value = value.reshape(value.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (query * scale) @ (key * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ output = weight @ value
+
+ # Reshape and return the final output
+ output = output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
+
+ return self.to_out(output)
+
+
+class LocalFacialExtractor(nn.Module):
+ def __init__(
+ self,
+ id_dim: int = 1280,
+ vit_dim: int = 1024,
+ depth: int = 10,
+ dim_head: int = 64,
+ heads: int = 16,
+ num_id_token: int = 5,
+ num_queries: int = 32,
+ output_dim: int = 2048,
+ ff_mult: int = 4,
+ num_scale: int = 5,
+ ):
+ super().__init__()
+
+ # Storing identity token and query information
+ self.num_id_token = num_id_token
+ self.vit_dim = vit_dim
+ self.num_queries = num_queries
+ assert depth % num_scale == 0
+ self.depth = depth // num_scale
+ self.num_scale = num_scale
+ scale = vit_dim**-0.5
+
+ # Learnable latent query embeddings
+ self.latents = nn.Parameter(torch.randn(1, num_queries, vit_dim) * scale)
+ # Projection layer to map the latent output to the desired dimension
+ self.proj_out = nn.Parameter(scale * torch.randn(vit_dim, output_dim))
+
+ # Attention and ConsisIDFeedForward layer stack
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=vit_dim, dim_head=dim_head, heads=heads), # Perceiver Attention layer
+ nn.Sequential(
+ nn.LayerNorm(vit_dim),
+ nn.Linear(vit_dim, vit_dim * ff_mult, bias=False),
+ nn.GELU(),
+ nn.Linear(vit_dim * ff_mult, vit_dim, bias=False),
+ ), # ConsisIDFeedForward layer
+ ]
+ )
+ )
+
+ # Mappings for each of the 5 different ViT features
+ for i in range(num_scale):
+ setattr(
+ self,
+ f"mapping_{i}",
+ nn.Sequential(
+ nn.Linear(vit_dim, vit_dim),
+ nn.LayerNorm(vit_dim),
+ nn.LeakyReLU(),
+ nn.Linear(vit_dim, vit_dim),
+ nn.LayerNorm(vit_dim),
+ nn.LeakyReLU(),
+ nn.Linear(vit_dim, vit_dim),
+ ),
+ )
+
+ # Mapping for identity embedding vectors
+ self.id_embedding_mapping = nn.Sequential(
+ nn.Linear(id_dim, vit_dim),
+ nn.LayerNorm(vit_dim),
+ nn.LeakyReLU(),
+ nn.Linear(vit_dim, vit_dim),
+ nn.LayerNorm(vit_dim),
+ nn.LeakyReLU(),
+ nn.Linear(vit_dim, vit_dim * num_id_token),
+ )
+
+ def forward(self, id_embeds: torch.Tensor, vit_hidden_states: List[torch.Tensor]) -> torch.Tensor:
+ # Repeat latent queries for the batch size
+ latents = self.latents.repeat(id_embeds.size(0), 1, 1)
+
+ # Map the identity embedding to tokens
+ id_embeds = self.id_embedding_mapping(id_embeds)
+ id_embeds = id_embeds.reshape(-1, self.num_id_token, self.vit_dim)
+
+ # Concatenate identity tokens with the latent queries
+ latents = torch.cat((latents, id_embeds), dim=1)
+
+ # Process each of the num_scale visual feature inputs
+ for i in range(self.num_scale):
+ vit_feature = getattr(self, f"mapping_{i}")(vit_hidden_states[i])
+ ctx_feature = torch.cat((id_embeds, vit_feature), dim=1)
+
+ # Pass through the PerceiverAttention and ConsisIDFeedForward layers
+ for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
+ latents = attn(ctx_feature, latents) + latents
+ latents = ff(latents) + latents
+
+ # Retain only the query latents
+ latents = latents[:, : self.num_queries]
+ # Project the latents to the output dimension
+ latents = latents @ self.proj_out
+ return latents
+
+
+class PerceiverCrossAttention(nn.Module):
+ def __init__(self, dim: int = 3072, dim_head: int = 128, heads: int = 16, kv_dim: int = 2048):
+ super().__init__()
+
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ # Layer normalization to stabilize training
+ self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ # Linear transformations to produce queries, keys, and values
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, image_embeds: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
+ # Apply layer normalization to the input image and latent features
+ image_embeds = self.norm1(image_embeds)
+ hidden_states = self.norm2(hidden_states)
+
+ batch_size, seq_len, _ = hidden_states.shape
+
+ # Compute queries, keys, and values
+ query = self.to_q(hidden_states)
+ key, value = self.to_kv(image_embeds).chunk(2, dim=-1)
+
+ # Reshape tensors to split into attention heads
+ query = query.reshape(query.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
+ key = key.reshape(key.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
+ value = value.reshape(value.size(0), -1, self.heads, self.dim_head).transpose(1, 2)
+
+ # Compute attention weights
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (query * scale) @ (key * scale).transpose(-2, -1) # More stable scaling than post-division
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
+ # Compute the output via weighted combination of values
+ out = weight @ value
+
+ # Reshape and permute to prepare for final linear transformation
+ out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
+
+ return self.to_out(out)
+
+
+@maybe_allow_in_graph
+class ConsisIDBlock(nn.Module):
+ r"""
+ Transformer block used in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
+class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+ instead of 13 because ConsisID processed 13 latent frames at once in its default and recommended settings,
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+ is_train_face (`bool`, defaults to `False`):
+ Whether to use enable the identity-preserving module during the training process. When set to `True`, the
+ model will focus on identity-preserving tasks.
+ is_kps (`bool`, defaults to `False`):
+ Whether to enable keypoint for global facial extractor. If `True`, keypoints will be in the model.
+ cross_attn_interval (`int`, defaults to `2`):
+ The interval between cross-attention layers in the Transformer architecture. A larger value may reduce the
+ frequency of cross-attention computations, which can help reduce computational overhead.
+ cross_attn_dim_head (`int`, optional, defaults to `128`):
+ The dimensionality of each attention head in the cross-attention layers of the Transformer architecture. A
+ larger value increases the capacity to attend to more complex patterns, but also increases memory and
+ computation costs.
+ cross_attn_num_heads (`int`, optional, defaults to `16`):
+ The number of attention heads in the cross-attention layers. More heads allow for more parallel attention
+ mechanisms, capturing diverse relationships between different components of the input, but can also
+ increase computational requirements.
+ LFE_id_dim (`int`, optional, defaults to `1280`):
+ The dimensionality of the identity vector used in the Local Facial Extractor (LFE). This vector represents
+ the identity features of a face, which are important for tasks like face recognition and identity
+ preservation across different frames.
+ LFE_vit_dim (`int`, optional, defaults to `1024`):
+ The dimension of the vision transformer (ViT) output used in the Local Facial Extractor (LFE). This value
+ dictates the size of the transformer-generated feature vectors that will be processed for facial feature
+ extraction.
+ LFE_depth (`int`, optional, defaults to `10`):
+ The number of layers in the Local Facial Extractor (LFE). Increasing the depth allows the model to capture
+ more complex representations of facial features, but also increases the computational load.
+ LFE_dim_head (`int`, optional, defaults to `64`):
+ The dimensionality of each attention head in the Local Facial Extractor (LFE). This parameter affects how
+ finely the model can process and focus on different parts of the facial features during the extraction
+ process.
+ LFE_num_heads (`int`, optional, defaults to `16`):
+ The number of attention heads in the Local Facial Extractor (LFE). More heads can improve the model's
+ ability to capture diverse facial features, but at the cost of increased computational complexity.
+ LFE_num_id_token (`int`, optional, defaults to `5`):
+ The number of identity tokens used in the Local Facial Extractor (LFE). This defines how many
+ identity-related tokens the model will process to ensure face identity preservation during feature
+ extraction.
+ LFE_num_querie (`int`, optional, defaults to `32`):
+ The number of query tokens used in the Local Facial Extractor (LFE). These tokens are used to capture
+ high-frequency face-related information that aids in accurate facial feature extraction.
+ LFE_output_dim (`int`, optional, defaults to `2048`):
+ The output dimension of the Local Facial Extractor (LFE). This dimension determines the size of the feature
+ vectors produced by the LFE module, which will be used for subsequent tasks such as face recognition or
+ tracking.
+ LFE_ff_mult (`int`, optional, defaults to `4`):
+ The multiplication factor applied to the feed-forward network's hidden layer size in the Local Facial
+ Extractor (LFE). A higher value increases the model's capacity to learn more complex facial feature
+ transformations, but also increases the computation and memory requirements.
+ LFE_num_scale (`int`, optional, defaults to `5`):
+ The number of different scales visual feature. A higher value increases the model's capacity to learn more
+ complex facial feature transformations, but also increases the computation and memory requirements.
+ local_face_scale (`float`, defaults to `1.0`):
+ A scaling factor used to adjust the importance of local facial features in the model. This can influence
+ how strongly the model focuses on high frequency face-related content.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 30,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ use_learned_positional_embeddings: bool = False,
+ is_train_face: bool = False,
+ is_kps: bool = False,
+ cross_attn_interval: int = 2,
+ cross_attn_dim_head: int = 128,
+ cross_attn_num_heads: int = 16,
+ LFE_id_dim: int = 1280,
+ LFE_vit_dim: int = 1024,
+ LFE_depth: int = 10,
+ LFE_dim_head: int = 64,
+ LFE_num_heads: int = 16,
+ LFE_num_id_token: int = 5,
+ LFE_num_querie: int = 32,
+ LFE_output_dim: int = 2048,
+ LFE_ff_mult: int = 4,
+ LFE_num_scale: int = 5,
+ local_face_scale: float = 1.0,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+ "There are no ConsisID checkpoints available with disable rotary embeddings and learned positional "
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+ "issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=not use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ )
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+ # 3. Define spatio-temporal transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ ConsisIDBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 4. Output blocks
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.is_train_face = is_train_face
+ self.is_kps = is_kps
+
+ # 5. Define identity-preserving config
+ if is_train_face:
+ # LFE configs
+ self.LFE_id_dim = LFE_id_dim
+ self.LFE_vit_dim = LFE_vit_dim
+ self.LFE_depth = LFE_depth
+ self.LFE_dim_head = LFE_dim_head
+ self.LFE_num_heads = LFE_num_heads
+ self.LFE_num_id_token = LFE_num_id_token
+ self.LFE_num_querie = LFE_num_querie
+ self.LFE_output_dim = LFE_output_dim
+ self.LFE_ff_mult = LFE_ff_mult
+ self.LFE_num_scale = LFE_num_scale
+ # cross configs
+ self.inner_dim = inner_dim
+ self.cross_attn_interval = cross_attn_interval
+ self.num_cross_attn = num_layers // cross_attn_interval
+ self.cross_attn_dim_head = cross_attn_dim_head
+ self.cross_attn_num_heads = cross_attn_num_heads
+ self.cross_attn_kv_dim = int(self.inner_dim / 3 * 2)
+ self.local_face_scale = local_face_scale
+ # face modules
+ self._init_face_inputs()
+
+ self.gradient_checkpointing = False
+
+ def _init_face_inputs(self):
+ self.local_facial_extractor = LocalFacialExtractor(
+ id_dim=self.LFE_id_dim,
+ vit_dim=self.LFE_vit_dim,
+ depth=self.LFE_depth,
+ dim_head=self.LFE_dim_head,
+ heads=self.LFE_num_heads,
+ num_id_token=self.LFE_num_id_token,
+ num_queries=self.LFE_num_querie,
+ output_dim=self.LFE_output_dim,
+ ff_mult=self.LFE_ff_mult,
+ num_scale=self.LFE_num_scale,
+ )
+ self.perceiver_cross_attention = nn.ModuleList(
+ [
+ PerceiverCrossAttention(
+ dim=self.inner_dim,
+ dim_head=self.cross_attn_dim_head,
+ heads=self.cross_attn_num_heads,
+ kv_dim=self.cross_attn_kv_dim,
+ )
+ for _ in range(self.num_cross_attn)
+ ]
+ )
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ id_cond: Optional[torch.Tensor] = None,
+ id_vit_hidden: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # fuse clip and insightface
+ valid_face_emb = None
+ if self.is_train_face:
+ id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype)
+ id_vit_hidden = [
+ tensor.to(device=hidden_states.device, dtype=hidden_states.dtype) for tensor in id_vit_hidden
+ ]
+ valid_face_emb = self.local_facial_extractor(
+ id_cond, id_vit_hidden
+ ) # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])
+
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ # torch.Size([1, 226, 4096]) torch.Size([1, 13, 32, 60, 90])
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # torch.Size([1, 17776, 3072])
+ hidden_states = self.embedding_dropout(hidden_states) # torch.Size([1, 17776, 3072])
+
+ text_seq_length = encoder_hidden_states.shape[1]
+ encoder_hidden_states = hidden_states[:, :text_seq_length] # torch.Size([1, 226, 3072])
+ hidden_states = hidden_states[:, text_seq_length:] # torch.Size([1, 17550, 3072])
+
+ # 3. Transformer blocks
+ ca_idx = 0
+ for i, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ if self.is_train_face:
+ if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
+ hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](
+ valid_face_emb, hidden_states
+ ) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
+ ca_idx += 1
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 4. Final block
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+ # Note: we use `-1` instead of `channels`:
+ # - It is okay to `channels` use for ConsisID (number of input channels is equal to output channels)
+ p = self.config.patch_size
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py
index 9f8957737dbc..cdc0738050e4 100644
--- a/src/diffusers/models/transformers/dit_transformer_2d.py
+++ b/src/diffusers/models/transformers/dit_transformer_2d.py
@@ -18,7 +18,7 @@
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
from ..attention import BasicTransformerBlock
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -64,7 +64,9 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
A small constant added to the denominator in normalization layers to prevent division by zero.
"""
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_supports_gradient_checkpointing = True
+ _supports_group_offloading = False
@register_to_config
def __init__(
@@ -143,10 +145,6 @@ def __init__(
self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -184,20 +182,9 @@ def forward(
# 2. Blocks
for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
None,
None,
@@ -205,7 +192,6 @@ def custom_forward(*inputs):
timestep,
cross_attention_kwargs,
class_labels,
- **ckpt_kwargs,
)
else:
hidden_states = block(
diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
index 7f3dab220aaa..550cc6d9d1e5 100644
--- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py
+++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -244,6 +244,9 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
"""
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
+ _supports_group_offloading = False
+
@register_to_config
def __init__(
self,
@@ -277,9 +280,7 @@ def __init__(
act_fn="silu_fp32",
)
- self.text_embedding_padding = nn.Parameter(
- torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
- )
+ self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim))
self.pos_embed = PatchEmbed(
height=sample_size,
diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py
index 71d19216e5ff..132c258455ea 100644
--- a/src/diffusers/models/transformers/latte_transformer_3d.py
+++ b/src/diffusers/models/transformers/latte_transformer_3d.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import Optional
import torch
@@ -19,13 +20,14 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..attention import BasicTransformerBlock
+from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
-class LatteTransformer3DModel(ModelMixin, ConfigMixin):
+class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True
"""
@@ -65,6 +67,8 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
The number of frames in the video-like data.
"""
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+
@register_to_config
def __init__(
self,
@@ -156,15 +160,12 @@ def __init__(
# define temporal positional embedding
temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
- inner_dim, torch.arange(0, video_length).unsqueeze(1)
+ inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
) # 1152 hidden size
- self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+ self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)
self.gradient_checkpointing = False
- def _set_gradient_checkpointing(self, module, value=False):
- self.gradient_checkpointing = value
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -226,20 +227,24 @@ def forward(
# Prepare text embeddings for spatial block
# batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
- encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
- -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
- )
+ encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
+ num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
+ ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
# Prepare timesteps for spatial and temporal block
- timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
- timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
+ timestep_spatial = timestep.repeat_interleave(
+ num_frame, dim=0, output_size=timestep.shape[0] * num_frame
+ ).view(-1, timestep.shape[-1])
+ timestep_temp = timestep.repeat_interleave(
+ num_patches, dim=0, output_size=timestep.shape[0] * num_patches
+ ).view(-1, timestep.shape[-1])
# Spatial and temporal transformer blocks
for i, (spatial_block, temp_block) in enumerate(
zip(self.transformer_blocks, self.temporal_transformer_blocks)
):
- if self.training and self.gradient_checkpointing:
- hidden_states = torch.utils.checkpoint.checkpoint(
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
spatial_block,
hidden_states,
None, # attention_mask
@@ -248,7 +253,6 @@ def forward(
timestep_spatial,
None, # cross_attention_kwargs
None, # class_labels
- use_reentrant=False,
)
else:
hidden_states = spatial_block(
@@ -269,10 +273,10 @@ def forward(
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
if i == 0 and num_frame > 1:
- hidden_states = hidden_states + self.temp_pos_embed
+ hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype)
- if self.training and self.gradient_checkpointing:
- hidden_states = torch.utils.checkpoint.checkpoint(
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
temp_block,
hidden_states,
None, # attention_mask
@@ -281,7 +285,6 @@ def forward(
timestep_temp,
None, # cross_attention_kwargs
None, # class_labels
- use_reentrant=False,
)
else:
hidden_states = temp_block(
@@ -300,7 +303,9 @@ def forward(
).permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
- embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
+ embedded_timestep = embedded_timestep.repeat_interleave(
+ num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
+ ).view(-1, embedded_timestep.shape[-1])
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py
index d4f5b4658542..320950866c4a 100644
--- a/src/diffusers/models/transformers/lumina_nextdit2d.py
+++ b/src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -98,7 +98,7 @@ def __init__(
self.feed_forward = LuminaFeedForward(
dim=dim,
- inner_dim=4 * dim,
+ inner_dim=int(4 * 2 * dim / 3),
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
)
@@ -221,6 +221,8 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
overall scale of the model's operations.
"""
+ _skip_layerwise_casting_patterns = ["patch_embedder", "norm", "ffn_norm"]
+
@register_to_config
def __init__(
self,
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
index 1e5cd5794517..8e290074a018 100644
--- a/src/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -17,7 +17,7 @@
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
from ..attention import BasicTransformerBlock
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
@register_to_config
def __init__(
@@ -183,10 +184,6 @@ def __init__(
in_features=self.config.caption_channels, hidden_size=self.inner_dim
)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -386,20 +383,9 @@ def forward(
# 2. Blocks
for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -407,7 +393,6 @@ def custom_forward(*inputs):
timestep,
cross_attention_kwargs,
None,
- **ckpt_kwargs,
)
else:
hidden_states = block(
diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py
index fdb67384ff5e..24d4e4d3d76f 100644
--- a/src/diffusers/models/transformers/prior_transformer.py
+++ b/src/diffusers/models/transformers/prior_transformer.py
@@ -353,7 +353,11 @@ def forward(
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
- attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
+ attention_mask = attention_mask.repeat_interleave(
+ self.config.num_attention_heads,
+ dim=0,
+ output_size=attention_mask.shape[0] * self.config.num_attention_heads,
+ )
if self.norm_in is not None:
hidden_states = self.norm_in(hidden_states)
diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py
new file mode 100644
index 000000000000..48b731406191
--- /dev/null
+++ b/src/diffusers/models/transformers/sana_transformer.py
@@ -0,0 +1,592 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import (
+ Attention,
+ AttentionProcessor,
+ SanaLinearAttnProcessor2_0,
+)
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class GLUMBConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expand_ratio: float = 4,
+ norm_type: Optional[str] = None,
+ residual_connection: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_channels = int(expand_ratio * in_channels)
+ self.norm_type = norm_type
+ self.residual_connection = residual_connection
+
+ self.nonlinearity = nn.SiLU()
+ self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
+ self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
+ self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
+
+ self.norm = None
+ if norm_type == "rms_norm":
+ self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.residual_connection:
+ residual = hidden_states
+
+ hidden_states = self.conv_inverted(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv_depth(hidden_states)
+ hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
+ hidden_states = hidden_states * self.nonlinearity(gate)
+
+ hidden_states = self.conv_point(hidden_states)
+
+ if self.norm_type == "rms_norm":
+ # move channel to the last dimension so we apply RMSnorm across channel dimension
+ hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+ if self.residual_connection:
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class SanaModulatedNorm(nn.Module):
+ def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+ def forward(
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
+ ) -> torch.Tensor:
+ hidden_states = self.norm(hidden_states)
+ shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+ return hidden_states
+
+
+class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
+ def __init__(self, embedding_dim):
+ super().__init__()
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+ def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ guidance_proj = self.guidance_condition_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
+ conditioning = timesteps_emb + guidance_emb
+
+ return self.linear(self.silu(conditioning)), conditioning
+
+
+class SanaAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class SanaTransformerBlock(nn.Module):
+ r"""
+ Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
+ """
+
+ def __init__(
+ self,
+ dim: int = 2240,
+ num_attention_heads: int = 70,
+ attention_head_dim: int = 32,
+ dropout: float = 0.0,
+ num_cross_attention_heads: Optional[int] = 20,
+ cross_attention_head_dim: Optional[int] = 112,
+ cross_attention_dim: Optional[int] = 2240,
+ attention_bias: bool = True,
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ attention_out_bias: bool = True,
+ mlp_ratio: float = 2.5,
+ qk_norm: Optional[str] = None,
+ ) -> None:
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ kv_heads=num_attention_heads if qk_norm is not None else None,
+ qk_norm=qk_norm,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ processor=SanaLinearAttnProcessor2_0(),
+ )
+
+ # 2. Cross Attention
+ if cross_attention_dim is not None:
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ self.attn2 = Attention(
+ query_dim=dim,
+ qk_norm=qk_norm,
+ kv_heads=num_cross_attention_heads if qk_norm is not None else None,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_cross_attention_heads,
+ dim_head=cross_attention_head_dim,
+ dropout=dropout,
+ bias=True,
+ out_bias=attention_out_bias,
+ processor=SanaAttnProcessor2_0(),
+ )
+
+ # 3. Feed-forward
+ self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ height: int = None,
+ width: int = None,
+ ) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ # 1. Modulation
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+
+ # 2. Self Attention
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
+
+ attn_output = self.attn1(norm_hidden_states)
+ hidden_states = hidden_states + gate_msa * attn_output
+
+ # 3. Cross Attention
+ if self.attn2 is not None:
+ attn_output = self.attn2(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
+ hidden_states = hidden_states + gate_mlp * ff_output
+
+ return hidden_states
+
+
+class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ r"""
+ A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
+
+ Args:
+ in_channels (`int`, defaults to `32`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `32`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `70`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `32`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `20`):
+ The number of layers of Transformer blocks to use.
+ num_cross_attention_heads (`int`, *optional*, defaults to `20`):
+ The number of heads to use for cross-attention.
+ cross_attention_head_dim (`int`, *optional*, defaults to `112`):
+ The number of channels in each head for cross-attention.
+ cross_attention_dim (`int`, *optional*, defaults to `2240`):
+ The number of channels in the cross-attention output.
+ caption_channels (`int`, defaults to `2304`):
+ The number of channels in the caption embeddings.
+ mlp_ratio (`float`, defaults to `2.5`):
+ The expansion ratio to use in the GLUMBConv layer.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability.
+ attention_bias (`bool`, defaults to `False`):
+ Whether to use bias in the attention layer.
+ sample_size (`int`, defaults to `32`):
+ The base size of the input latent.
+ patch_size (`int`, defaults to `1`):
+ The size of the patches to use in the patch embedding layer.
+ norm_elementwise_affine (`bool`, defaults to `False`):
+ Whether to use elementwise affinity in the normalization layer.
+ norm_eps (`float`, defaults to `1e-6`):
+ The epsilon value for the normalization layer.
+ qk_norm (`str`, *optional*, defaults to `None`):
+ The normalization to use for the query and key.
+ timestep_scale (`float`, defaults to `1.0`):
+ The scale to use for the timesteps.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
+ _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 32,
+ out_channels: Optional[int] = 32,
+ num_attention_heads: int = 70,
+ attention_head_dim: int = 32,
+ num_layers: int = 20,
+ num_cross_attention_heads: Optional[int] = 20,
+ cross_attention_head_dim: Optional[int] = 112,
+ cross_attention_dim: Optional[int] = 2240,
+ caption_channels: int = 2304,
+ mlp_ratio: float = 2.5,
+ dropout: float = 0.0,
+ attention_bias: bool = False,
+ sample_size: int = 32,
+ patch_size: int = 1,
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ interpolation_scale: Optional[int] = None,
+ guidance_embeds: bool = False,
+ guidance_embeds_scale: float = 0.1,
+ qk_norm: Optional[str] = None,
+ timestep_scale: float = 1.0,
+ ) -> None:
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # 1. Patch Embedding
+ self.patch_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ interpolation_scale=interpolation_scale,
+ pos_embed_type="sincos" if interpolation_scale is not None else None,
+ )
+
+ # 2. Additional condition embeddings
+ if guidance_embeds:
+ self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
+ else:
+ self.time_embed = AdaLayerNormSingle(inner_dim)
+
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+ self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
+
+ # 3. Transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ SanaTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ num_cross_attention_heads=num_cross_attention_heads,
+ cross_attention_head_dim=cross_attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ mlp_ratio=mlp_ratio,
+ qk_norm=qk_norm,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output blocks
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ guidance: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ batch_size, num_channels, height, width = hidden_states.shape
+ p = self.config.patch_size
+ post_patch_height, post_patch_width = height // p, width // p
+
+ hidden_states = self.patch_embed(hidden_states)
+
+ if guidance is not None:
+ timestep, embedded_timestep = self.time_embed(
+ timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ timestep, embedded_timestep = self.time_embed(
+ timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ encoder_hidden_states = self.caption_norm(encoder_hidden_states)
+
+ # 2. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.transformer_blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ post_patch_height,
+ post_patch_width,
+ )
+
+ else:
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ post_patch_height,
+ post_patch_width,
+ )
+
+ # 3. Normalization
+ hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
+
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1
+ )
+ hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
+ output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py
index e3462b51a412..d81b6447adb0 100644
--- a/src/diffusers/models/transformers/stable_audio_transformer.py
+++ b/src/diffusers/models/transformers/stable_audio_transformer.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Dict, Optional, Union
import numpy as np
import torch
@@ -29,7 +29,7 @@
)
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_2d import Transformer2DModelOutput
-from ...utils import is_torch_version, logging
+from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
@@ -211,6 +211,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"]
@register_to_config
def __init__(
@@ -345,10 +346,6 @@ def set_default_attn_processor(self):
"""
self.set_attn_processor(StableAudioAttnProcessor2_0())
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -414,26 +411,14 @@ def forward(
attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)
for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
attention_mask,
cross_attention_hidden_states,
encoder_attention_mask,
rotary_embedding,
- **ckpt_kwargs,
)
else:
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index c7c19e4582c6..a88ee6c9c9b8 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -18,7 +18,7 @@
from torch import nn
from ...configuration_utils import LegacyConfigMixin, register_to_config
-from ...utils import deprecate, is_torch_version, logging
+from ...utils import deprecate, logging
from ..attention import BasicTransformerBlock
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
@@ -66,6 +66,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["latent_image_embedding", "norm"]
@register_to_config
def __init__(
@@ -320,10 +321,6 @@ def _init_patched_inputs(self, norm_type):
in_features=self.caption_channels, hidden_size=self.inner_dim
)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -415,20 +412,9 @@ def forward(
# 2. Blocks
for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -436,7 +422,6 @@ def custom_forward(*inputs):
timestep,
cross_attention_kwargs,
class_labels,
- **ckpt_kwargs,
)
else:
hidden_states = block(
diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py
new file mode 100644
index 000000000000..d5c93409c932
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_allegro.py
@@ -0,0 +1,414 @@
+# Copyright 2024 The RhymesAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import AllegroAttnProcessor2_0, Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle
+
+
+logger = logging.get_logger(__name__)
+
+
+@maybe_allow_in_graph
+class AllegroTransformerBlock(nn.Module):
+ r"""
+ Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model.
+
+ Args:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ cross_attention_dim (`int`, defaults to `2304`):
+ The dimension of the cross attention features.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ only_cross_attention (`bool`, defaults to `False`):
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ attention_bias: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ processor=AllegroAttnProcessor2_0(),
+ )
+
+ # 2. Cross Attention
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ processor=AllegroAttnProcessor2_0(),
+ )
+
+ # 3. Feed Forward
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ )
+
+ # 4. Scale-shift
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ temb: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb=None,
+ ) -> torch.Tensor:
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.squeeze(1)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ attn_output = gate_msa * attn_output
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 1. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = hidden_states
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ image_rotary_emb=None,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 2. Feed-forward
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ # TODO(aryan): maybe following line is not required
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
+ _supports_gradient_checkpointing = True
+
+ """
+ A 3D Transformer model for video-like data.
+
+ Args:
+ patch_size (`int`, defaults to `2`):
+ The size of spatial patches to use in the patch embedding layer.
+ patch_size_t (`int`, defaults to `1`):
+ The size of temporal patches to use in the patch embedding layer.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `96`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `4`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `4`):
+ The number of channels in the output.
+ num_layers (`int`, defaults to `32`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ cross_attention_dim (`int`, defaults to `2304`):
+ The dimension of the cross attention features.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_height (`int`, defaults to `90`):
+ The height of the input latents.
+ sample_width (`int`, defaults to `160`):
+ The width of the input latents.
+ sample_frames (`int`, defaults to `22`):
+ The number of frames in the input latents.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ norm_elementwise_affine (`bool`, defaults to `False`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-6`):
+ The epsilon value to use in normalization layers.
+ caption_channels (`int`, defaults to `4096`):
+ Number of channels to use for projecting the caption embeddings.
+ interpolation_scale_h (`float`, defaults to `2.0`):
+ Scaling factor to apply in 3D positional embeddings across height dimension.
+ interpolation_scale_w (`float`, defaults to `2.0`):
+ Scaling factor to apply in 3D positional embeddings across width dimension.
+ interpolation_scale_t (`float`, defaults to `2.2`):
+ Scaling factor to apply in 3D positional embeddings across time dimension.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ patch_size_t: int = 1,
+ num_attention_heads: int = 24,
+ attention_head_dim: int = 96,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ num_layers: int = 32,
+ dropout: float = 0.0,
+ cross_attention_dim: int = 2304,
+ attention_bias: bool = True,
+ sample_height: int = 90,
+ sample_width: int = 160,
+ sample_frames: int = 22,
+ activation_fn: str = "gelu-approximate",
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ caption_channels: int = 4096,
+ interpolation_scale_h: float = 2.0,
+ interpolation_scale_w: float = 2.0,
+ interpolation_scale_t: float = 2.2,
+ ):
+ super().__init__()
+
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ interpolation_scale_t = (
+ interpolation_scale_t
+ if interpolation_scale_t is not None
+ else ((sample_frames - 1) // 16 + 1)
+ if sample_frames % 2 == 1
+ else sample_frames // 16
+ )
+ interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30
+ interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40
+
+ # 1. Patch embedding
+ self.pos_embed = PatchEmbed(
+ height=sample_height,
+ width=sample_width,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=self.inner_dim,
+ pos_embed_type=None,
+ )
+
+ # 2. Transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ AllegroTransformerBlock(
+ self.inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 3. Output projection & norm
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)
+
+ # 4. Timestep embeddings
+ self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)
+
+ # 5. Caption projection
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ return_dict: bool = True,
+ ):
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t = self.config.patch_size_t
+ p = self.config.patch_size
+
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p
+ post_patch_width = width // p
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) attention_mask_vid, attention_mask_img = None, None
+ if attention_mask is not None and attention_mask.ndim == 4:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ # b, frame+use_image_num, h, w -> a video with images
+ # b, 1, h, w -> only images
+ attention_mask = attention_mask.to(hidden_states.dtype)
+ attention_mask = attention_mask[:, :num_frames] # [batch_size, num_frames, height, width]
+
+ if attention_mask.numel() > 0:
+ attention_mask = attention_mask.unsqueeze(1) # [batch_size, 1, num_frames, height, width]
+ attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
+ attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1)
+
+ attention_mask = (
+ (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None
+ )
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Timestep embeddings
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ # 2. Patch embeddings
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ hidden_states = self.pos_embed(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
+
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
+
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ # TODO(aryan): Implement gradient checkpointing
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ timestep,
+ attention_mask,
+ encoder_attention_mask,
+ image_rotary_emb,
+ )
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=timestep,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # 4. Output normalization & projection
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.squeeze(1)
+
+ # 5. Unpatchify
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 962cbbff7c1b..da7133791f37 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Any, Dict, Union
+from typing import Dict, Union
import torch
import torch.nn as nn
@@ -27,7 +27,7 @@
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
-from ...utils import is_torch_version, logging
+from ...utils import logging
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
@@ -166,6 +166,8 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+ _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
@register_to_config
def __init__(
@@ -287,10 +289,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -341,21 +339,12 @@ def forward(
hidden_states = hidden_states[:, text_seq_length:]
for index_block, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
encoder_hidden_states,
emb,
- **ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
new file mode 100644
index 000000000000..41c4cbbf97c7
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -0,0 +1,462 @@
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class CogView4PatchEmbed(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 16,
+ hidden_size: int = 2560,
+ patch_size: int = 2,
+ text_hidden_size: int = 4096,
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+ self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, channel, height, width = hidden_states.shape
+ post_patch_height = height // self.patch_size
+ post_patch_width = width // self.patch_size
+
+ hidden_states = hidden_states.reshape(
+ batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
+ )
+ hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+ hidden_states = self.proj(hidden_states)
+ encoder_hidden_states = self.text_proj(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+
+
+class CogView4AdaLayerNormZero(nn.Module):
+ def __init__(self, embedding_dim: int, dim: int) -> None:
+ super().__init__()
+
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+
+ def forward(
+ self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ norm_hidden_states = self.norm(hidden_states)
+ norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+
+ emb = self.linear(temb)
+ (
+ shift_msa,
+ c_shift_msa,
+ scale_msa,
+ c_scale_msa,
+ gate_msa,
+ c_gate_msa,
+ shift_mlp,
+ c_shift_mlp,
+ scale_mlp,
+ c_scale_mlp,
+ gate_mlp,
+ c_gate_mlp,
+ ) = emb.chunk(12, dim=1)
+
+ hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+ encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
+
+ return (
+ hidden_states,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ encoder_hidden_states,
+ c_gate_msa,
+ c_shift_mlp,
+ c_scale_mlp,
+ c_gate_mlp,
+ )
+
+
+class CogView4AttnProcessor:
+ """
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+ batch_size, image_seq_length, embed_dim = hidden_states.shape
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ # 1. QKV projections
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ # 2. QK normalization
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # 3. Rotational positional embeddings applied to latent stream
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ query[:, :, text_seq_length:, :] = apply_rotary_emb(
+ query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+ )
+ key[:, :, text_seq_length:, :] = apply_rotary_emb(
+ key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+ )
+
+ # 4. Attention
+ if attention_mask is not None:
+ text_attention_mask = attention_mask.float().to(query.device)
+ actual_text_seq_length = text_attention_mask.size(1)
+ new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
+ new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
+ new_attention_mask = new_attention_mask.unsqueeze(2)
+ attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
+ attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ # 5. Output projection
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
+ return hidden_states, encoder_hidden_states
+
+
+class CogView4TransformerBlock(nn.Module):
+ def __init__(
+ self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
+ ) -> None:
+ super().__init__()
+
+ # 1. Attention
+ self.norm1 = CogView4AdaLayerNormZero(time_embed_dim, dim)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ out_dim=dim,
+ bias=True,
+ qk_norm="layer_norm",
+ elementwise_affine=False,
+ eps=1e-5,
+ processor=CogView4AttnProcessor(),
+ )
+
+ # 2. Feedforward
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ # 1. Timestep conditioning
+ (
+ norm_hidden_states,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ norm_encoder_hidden_states,
+ c_gate_msa,
+ c_shift_mlp,
+ c_scale_mlp,
+ c_gate_mlp,
+ ) = self.norm1(hidden_states, encoder_hidden_states, temb)
+
+ # 2. Attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
+ **kwargs,
+ )
+ hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
+ encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
+
+ # 3. Feedforward
+ norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
+ 1 + c_scale_mlp.unsqueeze(1)
+ ) + c_shift_mlp.unsqueeze(1)
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output_context = self.ff(norm_encoder_hidden_states)
+ hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
+ encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
+
+ return hidden_states, encoder_hidden_states
+
+
+class CogView4RotaryPosEmbed(nn.Module):
+ def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
+ super().__init__()
+
+ self.dim = dim
+ self.patch_size = patch_size
+ self.rope_axes_dim = rope_axes_dim
+ self.theta = theta
+
+ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size, num_channels, height, width = hidden_states.shape
+ height, width = height // self.patch_size, width // self.patch_size
+
+ dim_h, dim_w = self.dim // 2, self.dim // 2
+ h_inv_freq = 1.0 / (
+ self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
+ )
+ w_inv_freq = 1.0 / (
+ self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
+ )
+ h_seq = torch.arange(self.rope_axes_dim[0])
+ w_seq = torch.arange(self.rope_axes_dim[1])
+ freqs_h = torch.outer(h_seq, h_inv_freq)
+ freqs_w = torch.outer(w_seq, w_inv_freq)
+
+ h_idx = torch.arange(height, device=freqs_h.device)
+ w_idx = torch.arange(width, device=freqs_w.device)
+ inner_h_idx = h_idx * self.rope_axes_dim[0] // height
+ inner_w_idx = w_idx * self.rope_axes_dim[1] // width
+
+ freqs_h = freqs_h[inner_h_idx]
+ freqs_w = freqs_w[inner_w_idx]
+
+ # Create position matrices for height and width
+ # [height, 1, dim//4] and [1, width, dim//4]
+ freqs_h = freqs_h.unsqueeze(1)
+ freqs_w = freqs_w.unsqueeze(0)
+ # Broadcast freqs_h and freqs_w to [height, width, dim//4]
+ freqs_h = freqs_h.expand(height, width, -1)
+ freqs_w = freqs_w.expand(height, width, -1)
+
+ # Concatenate along last dimension to get [height, width, dim//2]
+ freqs = torch.cat([freqs_h, freqs_w], dim=-1)
+ freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
+ freqs = freqs.reshape(height * width, -1)
+ return (freqs.cos(), freqs.sin())
+
+
+class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
+ r"""
+ Args:
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ attention_head_dim (`int`, defaults to `40`):
+ The number of channels in each head.
+ num_attention_heads (`int`, defaults to `64`):
+ The number of heads to use for multi-head attention.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ condition_dim (`int`, defaults to `256`):
+ The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
+ crop_coords).
+ pos_embed_max_size (`int`, defaults to `128`):
+ The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
+ to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
+ means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
+ patch_size => 128 * 8 * 2 => 2048`.
+ sample_size (`int`, defaults to `128`):
+ The base resolution of input latents. If height/width is not provided during generation, this value is used
+ to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["CogView4TransformerBlock", "CogView4PatchEmbed", "CogView4PatchEmbed"]
+ _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ num_layers: int = 30,
+ attention_head_dim: int = 40,
+ num_attention_heads: int = 64,
+ text_embed_dim: int = 4096,
+ time_embed_dim: int = 512,
+ condition_dim: int = 256,
+ pos_embed_max_size: int = 128,
+ sample_size: int = 128,
+ rope_axes_dim: Tuple[int, int] = (256, 256),
+ ):
+ super().__init__()
+
+ # CogView4 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
+ # Each of these are sincos embeddings of shape 2 * condition_dim
+ pooled_projection_dim = 3 * 2 * condition_dim
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels
+
+ # 1. RoPE
+ self.rope = CogView4RotaryPosEmbed(attention_head_dim, patch_size, rope_axes_dim, theta=10000.0)
+
+ # 2. Patch & Text-timestep embedding
+ self.patch_embed = CogView4PatchEmbed(in_channels, inner_dim, patch_size, text_embed_dim)
+
+ self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
+ embedding_dim=time_embed_dim,
+ condition_dim=condition_dim,
+ pooled_projection_dim=pooled_projection_dim,
+ timesteps_dim=inner_dim,
+ )
+
+ # 3. Transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogView4TransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output projection
+ self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ original_size: torch.Tensor,
+ target_size: torch.Tensor,
+ crop_coords: torch.Tensor,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, height, width = hidden_states.shape
+
+ # 1. RoPE
+ image_rotary_emb = self.rope(hidden_states)
+
+ # 2. Patch & Timestep embeddings
+ p = self.config.patch_size
+ post_patch_height = height // p
+ post_patch_width = width // p
+
+ hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
+
+ temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+ temb = F.silu(temb)
+
+ # 3. Transformer blocks
+ for block in self.transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+ )
+
+ # 4. Output norm & projection
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+ hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
+ output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_easyanimate.py b/src/diffusers/models/transformers/transformer_easyanimate.py
new file mode 100755
index 000000000000..545fa29730db
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_easyanimate.py
@@ -0,0 +1,527 @@
+# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import Attention, FeedForward
+from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class EasyAnimateLayerNormZero(nn.Module):
+ def __init__(
+ self,
+ conditioning_dim: int,
+ embedding_dim: int,
+ elementwise_affine: bool = True,
+ eps: float = 1e-5,
+ bias: bool = True,
+ norm_type: str = "fp32_layer_norm",
+ ) -> None:
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
+
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
+ elif norm_type == "fp32_layer_norm":
+ self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
+ else:
+ raise ValueError(
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+ )
+
+ def forward(
+ self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
+ hidden_states = self.norm(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+ encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale.unsqueeze(1)) + enc_shift.unsqueeze(
+ 1
+ )
+ return hidden_states, encoder_hidden_states, gate, enc_gate
+
+
+class EasyAnimateRotaryPosEmbed(nn.Module):
+ def __init__(self, patch_size: int, rope_dim: List[int]) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.rope_dim = rope_dim
+
+ def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ bs, c, num_frames, grid_height, grid_width = hidden_states.size()
+ grid_height = grid_height // self.patch_size
+ grid_width = grid_width // self.patch_size
+ base_size_width = 90 // self.patch_size
+ base_size_height = 60 // self.patch_size
+
+ grid_crops_coords = self.get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ image_rotary_emb = get_3d_rotary_pos_embed(
+ self.rope_dim,
+ grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=hidden_states.size(2),
+ use_real=True,
+ )
+ return image_rotary_emb
+
+
+class EasyAnimateAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the EasyAnimateTransformer3DModel model.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "EasyAnimateAttnProcessor2_0 requires PyTorch 2.0 or above. To use it, please install PyTorch 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ # 1. QKV projections
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ # 2. QK normalization
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # 3. Encoder condition QKV projection and normalization
+ if attn.add_q_proj is not None and encoder_hidden_states is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_query = attn.norm_added_q(encoder_query)
+ if attn.norm_added_k is not None:
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=2)
+ key = torch.cat([encoder_key, key], dim=2)
+ value = torch.cat([encoder_value, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ query[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
+ query[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
+ )
+ if not attn.is_cross_attention:
+ key[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
+ key[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
+ )
+
+ # 5. Attention
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # 6. Output projection
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ if getattr(attn, "to_out", None) is not None:
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if getattr(attn, "to_add_out", None) is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+ else:
+ if getattr(attn, "to_out", None) is not None:
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states, encoder_hidden_states
+
+
+@maybe_allow_in_graph
+class EasyAnimateTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-6,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ qk_norm: bool = True,
+ after_norm: bool = False,
+ norm_type: str = "fp32_layer_norm",
+ is_mmdit_block: bool = True,
+ ):
+ super().__init__()
+
+ # Attention Part
+ self.norm1 = EasyAnimateLayerNormZero(
+ time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
+ )
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=True,
+ added_proj_bias=True,
+ added_kv_proj_dim=dim if is_mmdit_block else None,
+ context_pre_only=False if is_mmdit_block else None,
+ processor=EasyAnimateAttnProcessor2_0(),
+ )
+
+ # FFN Part
+ self.norm2 = EasyAnimateLayerNormZero(
+ time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
+ )
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ self.txt_ff = None
+ if is_mmdit_block:
+ self.txt_ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ self.norm3 = None
+ if after_norm:
+ self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Attention
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
+
+ # 2. Feed-forward
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+ if self.norm3 is not None:
+ norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
+ if self.txt_ff is not None:
+ norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
+ else:
+ norm_encoder_hidden_states = self.norm3(self.ff(norm_encoder_hidden_states))
+ else:
+ norm_hidden_states = self.ff(norm_hidden_states)
+ if self.txt_ff is not None:
+ norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
+ else:
+ norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states)
+ hidden_states = hidden_states + gate_ff.unsqueeze(1) * norm_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff.unsqueeze(1) * norm_encoder_hidden_states
+ return hidden_states, encoder_hidden_states
+
+
+class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
+ """
+ A Transformer model for video-like data in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `48`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ mmdit_layers (`int`, defaults to `1000`):
+ The number of layers of Multi Modal Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use elementwise affine in normalization layers.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_position_encoding_type (`str`, defaults to `3d_rope`):
+ Type of time position encoding.
+ after_norm (`bool`, defaults to `False`):
+ Flag to apply normalization after.
+ resize_inpaint_mask_directly (`bool`, defaults to `True`):
+ Flag to resize inpaint mask directly.
+ enable_text_attention_mask (`bool`, defaults to `True`):
+ Flag to enable text attention mask.
+ add_noise_in_inpaint_model (`bool`, defaults to `False`):
+ Flag to add noise in inpaint model.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["EasyAnimateTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["^proj$", "norm", "^proj_out$"]
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 48,
+ attention_head_dim: int = 64,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ freq_shift: int = 0,
+ num_layers: int = 48,
+ mmdit_layers: int = 48,
+ dropout: float = 0.0,
+ time_embed_dim: int = 512,
+ add_norm_text_encoder: bool = False,
+ text_embed_dim: int = 3584,
+ text_embed_dim_t5: int = None,
+ norm_eps: float = 1e-5,
+ norm_elementwise_affine: bool = True,
+ flip_sin_to_cos: bool = True,
+ time_position_encoding_type: str = "3d_rope",
+ after_norm=False,
+ resize_inpaint_mask_directly: bool = True,
+ enable_text_attention_mask: bool = True,
+ add_noise_in_inpaint_model: bool = True,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # 1. Timestep embedding
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+ self.rope_embedding = EasyAnimateRotaryPosEmbed(patch_size, attention_head_dim)
+
+ # 2. Patch embedding
+ self.proj = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
+ )
+
+ # 3. Text refined embedding
+ self.text_proj = None
+ self.text_proj_t5 = None
+ if not add_norm_text_encoder:
+ self.text_proj = nn.Linear(text_embed_dim, inner_dim)
+ if text_embed_dim_t5 is not None:
+ self.text_proj_t5 = nn.Linear(text_embed_dim_t5, inner_dim)
+ else:
+ self.text_proj = nn.Sequential(
+ RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim, inner_dim)
+ )
+ if text_embed_dim_t5 is not None:
+ self.text_proj_t5 = nn.Sequential(
+ RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim_t5, inner_dim)
+ )
+
+ # 4. Transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ EasyAnimateTransformerBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ after_norm=after_norm,
+ is_mmdit_block=True if _ < mmdit_layers else False,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 5. Output norm & projection
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ timestep_cond: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_hidden_states_t5: Optional[torch.Tensor] = None,
+ inpaint_latents: Optional[torch.Tensor] = None,
+ control_latents: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+ batch_size, channels, video_length, height, width = hidden_states.size()
+ p = self.config.patch_size
+ post_patch_height = height // p
+ post_patch_width = width // p
+
+ # 1. Time embedding
+ temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
+ temb = self.time_embedding(temb, timestep_cond)
+ image_rotary_emb = self.rope_embedding(hidden_states)
+
+ # 2. Patch embedding
+ if inpaint_latents is not None:
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
+ if control_latents is not None:
+ hidden_states = torch.concat([hidden_states, control_latents], 1)
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, F, H, W] -> [BF, C, H, W]
+ hidden_states = self.proj(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
+ 0, 2, 1, 3, 4
+ ) # [BF, C, H, W] -> [B, F, C, H, W]
+ hidden_states = hidden_states.flatten(2, 4).transpose(1, 2) # [B, F, C, H, W] -> [B, FHW, C]
+
+ # 3. Text embedding
+ encoder_hidden_states = self.text_proj(encoder_hidden_states)
+ if encoder_hidden_states_t5 is not None:
+ encoder_hidden_states_t5 = self.text_proj_t5(encoder_hidden_states_t5)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1).contiguous()
+
+ # 4. Transformer blocks
+ for block in self.transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states, encoder_hidden_states, temb, image_rotary_emb
+ )
+
+ hidden_states = self.norm_final(hidden_states)
+
+ # 5. Output norm & projection
+ hidden_states = self.norm_out(hidden_states, temb=temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 6. Unpatchify
+ p = self.config.patch_size
+ output = hidden_states.reshape(batch_size, video_length, post_patch_height, post_patch_width, channels, p, p)
+ output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 5d39a1bb5391..87537890d246 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -18,21 +18,23 @@
import numpy as np
import torch
import torch.nn as nn
-import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
AttentionProcessor,
FluxAttnProcessor2_0,
+ FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
+from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -42,20 +44,7 @@
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
- r"""
- A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
-
- Reference: https://arxiv.org/abs/2403.03206
-
- Parameters:
- dim (`int`): The number of channels in the input and output.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
- processing of `context` conditions.
- """
-
- def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
+ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
@@ -64,7 +53,16 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
- processor = FluxAttnProcessor2_0()
+ if is_torch_npu_available():
+ deprecation_message = (
+ "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
+ "should be set explicitly using the `set_attn_processor` method."
+ )
+ deprecate("npu_processor", "0.34.0", deprecation_message)
+ processor = FluxAttnProcessor2_0_NPU()
+ else:
+ processor = FluxAttnProcessor2_0()
+
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
@@ -80,11 +78,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
def forward(
self,
- hidden_states: torch.FloatTensor,
- temb: torch.FloatTensor,
- image_rotary_emb=None,
- joint_attention_kwargs=None,
- ):
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -107,32 +105,14 @@ def forward(
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
- r"""
- A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
-
- Reference: https://arxiv.org/abs/2403.03206
-
- Parameters:
- dim (`int`): The number of channels in the input and output.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
- processing of `context` conditions.
- """
-
- def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+ def __init__(
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+ ):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
-
self.norm1_context = AdaLayerNormZero(dim)
- if hasattr(F, "scaled_dot_product_attention"):
- processor = FluxAttnProcessor2_0()
- else:
- raise ValueError(
- "The current PyTorch version does not support the `scaled_dot_product_attention` function."
- )
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
@@ -142,7 +122,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_no
out_dim=dim,
context_pre_only=False,
bias=True,
- processor=processor,
+ processor=FluxAttnProcessor2_0(),
qk_norm=qk_norm,
eps=eps,
)
@@ -153,18 +133,14 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_no
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
- # let chunk size default to None
- self._chunk_size = None
- self._chunk_dim = 0
-
def forward(
self,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor,
- temb: torch.FloatTensor,
- image_rotary_emb=None,
- joint_attention_kwargs=None,
- ):
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
@@ -172,13 +148,18 @@ def forward(
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
- attn_output, context_attn_output = self.attn(
+ attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
@@ -190,6 +171,8 @@ def forward(
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
@@ -207,32 +190,50 @@ def forward(
return encoder_hidden_states, hidden_states
-class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class FluxTransformer2DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+):
"""
The Transformer model introduced in Flux.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
- Parameters:
- patch_size (`int`): Patch size to turn the input data into small patches.
- in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
- num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
- num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
- attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
- num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
- joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
- pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
- guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+ Args:
+ patch_size (`int`, defaults to `1`):
+ Patch size to turn the input data into small patches.
+ in_channels (`int`, defaults to `64`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `None`):
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
+ num_layers (`int`, defaults to `19`):
+ The number of layers of dual stream DiT blocks to use.
+ num_single_layers (`int`, defaults to `38`):
+ The number of layers of single stream DiT blocks to use.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of dimensions to use for each attention head.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of attention heads to use.
+ joint_attention_dim (`int`, defaults to `4096`):
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
+ `encoder_hidden_states`).
+ pooled_projection_dim (`int`, defaults to `768`):
+ The number of dimensions to use for the pooled projection.
+ guidance_embeds (`bool`, defaults to `False`):
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions to use for the rotary positional embeddings.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
+ out_channels: Optional[int] = None,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
@@ -243,8 +244,8 @@ def __init__(
axes_dims_rope: Tuple[int] = (16, 56, 56),
):
super().__init__()
- self.out_channels = in_channels
- self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
@@ -252,20 +253,20 @@ def __init__(
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
- embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
- self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
- self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
- num_attention_heads=self.config.num_attention_heads,
- attention_head_dim=self.config.attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
)
- for i in range(self.config.num_layers)
+ for _ in range(num_layers)
]
)
@@ -273,10 +274,10 @@ def __init__(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
- num_attention_heads=self.config.num_attention_heads,
- attention_head_dim=self.config.attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
)
- for i in range(self.config.num_single_layers)
+ for _ in range(num_single_layers)
]
)
@@ -385,10 +386,6 @@ def unfuse_qkv_projections(self):
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -403,16 +400,16 @@ def forward(
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
@@ -444,6 +441,7 @@ def forward(
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
+
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
@@ -451,6 +449,7 @@ def forward(
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
+
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
@@ -474,26 +473,19 @@ def forward(
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
- for index_block, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
- **ckpt_kwargs,
)
else:
@@ -516,28 +508,15 @@ def custom_forward(*inputs):
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
-
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
temb,
image_rotary_emb,
- **ckpt_kwargs,
)
else:
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
new file mode 100644
index 000000000000..36f914f0b5c1
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -0,0 +1,1149 @@
+# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.loaders import FromOriginalModelMixin
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention, AttentionProcessor
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+ CombinedTimestepTextProjEmbeddings,
+ PixArtAlphaTextProjection,
+ TimestepEmbedding,
+ Timesteps,
+ get_1d_rotary_pos_embed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class HunyuanVideoAttnProcessor2_0:
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ # 1. QKV projections
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ # 2. QK normalization
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # 3. Rotational positional embeddings applied to latent stream
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ query = torch.cat(
+ [
+ apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+ query[:, :, -encoder_hidden_states.shape[1] :],
+ ],
+ dim=2,
+ )
+ key = torch.cat(
+ [
+ apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+ key[:, :, -encoder_hidden_states.shape[1] :],
+ ],
+ dim=2,
+ )
+ else:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ # 4. Encoder condition QKV projection and normalization
+ if attn.add_q_proj is not None and encoder_hidden_states is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_query = attn.norm_added_q(encoder_query)
+ if attn.norm_added_k is not None:
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([query, encoder_query], dim=2)
+ key = torch.cat([key, encoder_key], dim=2)
+ value = torch.cat([value, encoder_value], dim=2)
+
+ # 5. Attention
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # 6. Output projection
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : -encoder_hidden_states.shape[1]],
+ hidden_states[:, -encoder_hidden_states.shape[1] :],
+ )
+
+ if getattr(attn, "to_out", None) is not None:
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if getattr(attn, "to_add_out", None) is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: Union[int, Tuple[int, int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ super().__init__()
+
+ patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
+ return hidden_states
+
+
+class HunyuanVideoAdaNorm(nn.Module):
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
+ super().__init__()
+
+ out_features = out_features or 2 * in_features
+ self.linear = nn.Linear(in_features, out_features)
+ self.nonlinearity = nn.SiLU()
+
+ def forward(
+ self, temb: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ temb = self.linear(self.nonlinearity(temb))
+ gate_msa, gate_mlp = temb.chunk(2, dim=1)
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
+ return gate_msa, gate_mlp
+
+
+class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
+ def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
+
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+ elif norm_type == "fp32_layer_norm":
+ self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+ else:
+ raise ValueError(
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ emb: torch.Tensor,
+ token_replace_emb: torch.Tensor,
+ first_frame_num_tokens: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ emb = self.linear(self.silu(emb))
+ token_replace_emb = self.linear(self.silu(token_replace_emb))
+
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+ tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
+ 6, dim=1
+ )
+
+ norm_hidden_states = self.norm(hidden_states)
+ hidden_states_zero = (
+ norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
+ )
+ hidden_states_orig = (
+ norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ )
+ hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+
+ return (
+ hidden_states,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ tr_gate_msa,
+ tr_shift_mlp,
+ tr_scale_mlp,
+ tr_gate_mlp,
+ )
+
+
+class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
+ def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
+
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ raise ValueError(
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ emb: torch.Tensor,
+ token_replace_emb: torch.Tensor,
+ first_frame_num_tokens: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ emb = self.linear(self.silu(emb))
+ token_replace_emb = self.linear(self.silu(token_replace_emb))
+
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
+ tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
+
+ norm_hidden_states = self.norm(hidden_states)
+ hidden_states_zero = (
+ norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
+ )
+ hidden_states_orig = (
+ norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ )
+ hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+
+ return hidden_states, gate_msa, tr_gate_msa
+
+
+class HunyuanVideoConditionEmbedding(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ pooled_projection_dim: int,
+ guidance_embeds: bool,
+ image_condition_type: Optional[str] = None,
+ ):
+ super().__init__()
+
+ self.image_condition_type = image_condition_type
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ self.guidance_embedder = None
+ if guidance_embeds:
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(
+ self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+ pooled_projections = self.text_embedder(pooled_projection)
+ conditioning = timesteps_emb + pooled_projections
+
+ token_replace_emb = None
+ if self.image_condition_type == "token_replace":
+ token_replace_timestep = torch.zeros_like(timestep)
+ token_replace_proj = self.time_proj(token_replace_timestep)
+ token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
+ token_replace_emb = token_replace_emb + pooled_projections
+
+ if self.guidance_embedder is not None:
+ guidance_proj = self.time_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
+ conditioning = conditioning + guidance_emb
+
+ return conditioning, token_replace_emb
+
+
+class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_width_ratio: str = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=attention_bias,
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
+
+ self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=attention_mask,
+ )
+
+ gate_msa, gate_mlp = self.norm_out(temb)
+ hidden_states = hidden_states + attn_output * gate_msa
+
+ ff_output = self.ff(self.norm2(hidden_states))
+ hidden_states = hidden_states + ff_output * gate_mlp
+
+ return hidden_states
+
+
+class HunyuanVideoIndividualTokenRefiner(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_layers: int,
+ mlp_width_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.refiner_blocks = nn.ModuleList(
+ [
+ HunyuanVideoIndividualTokenRefinerBlock(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ attention_bias=attention_bias,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> None:
+ self_attn_mask = None
+ if attention_mask is not None:
+ batch_size = attention_mask.shape[0]
+ seq_len = attention_mask.shape[1]
+ attention_mask = attention_mask.to(hidden_states.device).bool()
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
+ self_attn_mask[:, :, :, 0] = True
+
+ for block in self.refiner_blocks:
+ hidden_states = block(hidden_states, temb, self_attn_mask)
+
+ return hidden_states
+
+
+class HunyuanVideoTokenRefiner(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_layers: int,
+ mlp_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+ embedding_dim=hidden_size, pooled_projection_dim=in_channels
+ )
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
+ self.token_refiner = HunyuanVideoIndividualTokenRefiner(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ num_layers=num_layers,
+ mlp_width_ratio=mlp_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ attention_bias=attention_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ) -> torch.Tensor:
+ if attention_mask is None:
+ pooled_projections = hidden_states.mean(dim=1)
+ else:
+ original_dtype = hidden_states.dtype
+ mask_float = attention_mask.float().unsqueeze(-1)
+ pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
+ pooled_projections = pooled_projections.to(original_dtype)
+
+ temb = self.time_text_embed(timestep, pooled_projections)
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
+
+ return hidden_states
+
+
+class HunyuanVideoRotaryPosEmbed(nn.Module):
+ def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.rope_dim = rope_dim
+ self.theta = theta
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
+
+ axes_grids = []
+ for i in range(3):
+ # Note: The following line diverges from original behaviour. We create the grid on the device, whereas
+ # original implementation creates it on CPU and then moves it to device. This results in numerical
+ # differences in layerwise debugging outputs, but visually it is the same.
+ grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
+ axes_grids.append(grid)
+ grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
+ grid = torch.stack(grid, dim=0) # [3, W, H, T]
+
+ freqs = []
+ for i in range(3):
+ freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+ freqs.append(freq)
+
+ freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
+ freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
+ return freqs_cos, freqs_sin
+
+
+class HunyuanVideoSingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 4.0,
+ qk_norm: str = "rms_norm",
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+ mlp_dim = int(hidden_size * mlp_ratio)
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ bias=True,
+ processor=HunyuanVideoAttnProcessor2_0(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ residual = hidden_states
+
+ # 1. Input normalization
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+ norm_hidden_states, norm_encoder_hidden_states = (
+ norm_hidden_states[:, :-text_seq_length, :],
+ norm_hidden_states[:, -text_seq_length:, :],
+ )
+
+ # 2. Attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+
+ # 3. Modulation and residual connection
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
+ hidden_states = hidden_states + residual
+
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, :-text_seq_length, :],
+ hidden_states[:, -text_seq_length:, :],
+ )
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float,
+ qk_norm: str = "rms_norm",
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ added_kv_proj_dim=hidden_size,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ context_pre_only=False,
+ bias=True,
+ processor=HunyuanVideoAttnProcessor2_0(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Input normalization
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # 2. Joint attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=freqs_cis,
+ )
+
+ # 3. Modulation and residual connection
+ hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ # 4. Feed-forward
+ ff_output = self.ff(norm_hidden_states)
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 4.0,
+ qk_norm: str = "rms_norm",
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+ mlp_dim = int(hidden_size * mlp_ratio)
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ bias=True,
+ processor=HunyuanVideoAttnProcessor2_0(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ token_replace_emb: torch.Tensor = None,
+ num_tokens: int = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ residual = hidden_states
+
+ # 1. Input normalization
+ norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+ norm_hidden_states, norm_encoder_hidden_states = (
+ norm_hidden_states[:, :-text_seq_length, :],
+ norm_hidden_states[:, -text_seq_length:, :],
+ )
+
+ # 2. Attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+
+ # 3. Modulation and residual connection
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+
+ proj_output = self.proj_out(hidden_states)
+ hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
+ hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
+ hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+ hidden_states = hidden_states + residual
+
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, :-text_seq_length, :],
+ hidden_states[:, -text_seq_length:, :],
+ )
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float,
+ qk_norm: str = "rms_norm",
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ added_kv_proj_dim=hidden_size,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ context_pre_only=False,
+ bias=True,
+ processor=HunyuanVideoAttnProcessor2_0(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ token_replace_emb: torch.Tensor = None,
+ num_tokens: int = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Input normalization
+ (
+ norm_hidden_states,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ tr_gate_msa,
+ tr_shift_mlp,
+ tr_scale_mlp,
+ tr_gate_mlp,
+ ) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # 2. Joint attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=freqs_cis,
+ )
+
+ # 3. Modulation and residual connection
+ hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
+ hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
+ hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+ hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
+ hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ # 4. Feed-forward
+ ff_output = self.ff(norm_hidden_states)
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+ hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
+ hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
+ hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ r"""
+ A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
+
+ Args:
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `20`):
+ The number of layers of dual-stream blocks to use.
+ num_single_layers (`int`, defaults to `40`):
+ The number of layers of single-stream blocks to use.
+ num_refiner_layers (`int`, defaults to `2`):
+ The number of layers of refiner blocks to use.
+ mlp_ratio (`float`, defaults to `4.0`):
+ The ratio of the hidden layer size to the input size in the feedforward network.
+ patch_size (`int`, defaults to `2`):
+ The size of the spatial patches to use in the patch embedding layer.
+ patch_size_t (`int`, defaults to `1`):
+ The size of the tmeporal patches to use in the patch embedding layer.
+ qk_norm (`str`, defaults to `rms_norm`):
+ The normalization to use for the query and key projections in the attention layers.
+ guidance_embeds (`bool`, defaults to `True`):
+ Whether to use guidance embeddings in the model.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ pooled_projection_dim (`int`, defaults to `768`):
+ The dimension of the pooled projection of the text embeddings.
+ rope_theta (`float`, defaults to `256.0`):
+ The value of theta to use in the RoPE layer.
+ rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions of the axes to use in the RoPE layer.
+ image_condition_type (`str`, *optional*, defaults to `None`):
+ The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
+ image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
+ tokens in the latent stream and apply conditioning.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
+ _no_split_modules = [
+ "HunyuanVideoTransformerBlock",
+ "HunyuanVideoSingleTransformerBlock",
+ "HunyuanVideoPatchEmbed",
+ "HunyuanVideoTokenRefiner",
+ ]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ num_attention_heads: int = 24,
+ attention_head_dim: int = 128,
+ num_layers: int = 20,
+ num_single_layers: int = 40,
+ num_refiner_layers: int = 2,
+ mlp_ratio: float = 4.0,
+ patch_size: int = 2,
+ patch_size_t: int = 1,
+ qk_norm: str = "rms_norm",
+ guidance_embeds: bool = True,
+ text_embed_dim: int = 4096,
+ pooled_projection_dim: int = 768,
+ rope_theta: float = 256.0,
+ rope_axes_dim: Tuple[int] = (16, 56, 56),
+ image_condition_type: Optional[str] = None,
+ ) -> None:
+ super().__init__()
+
+ supported_image_condition_types = ["latent_concat", "token_replace"]
+ if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
+ raise ValueError(
+ f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
+ )
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Latent and condition embedders
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+ self.context_embedder = HunyuanVideoTokenRefiner(
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
+ )
+
+ self.time_text_embed = HunyuanVideoConditionEmbedding(
+ inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
+ )
+
+ # 2. RoPE
+ self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+
+ # 3. Dual stream transformer blocks
+ if image_condition_type == "token_replace":
+ self.transformer_blocks = nn.ModuleList(
+ [
+ HunyuanVideoTokenReplaceTransformerBlock(
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ else:
+ self.transformer_blocks = nn.ModuleList(
+ [
+ HunyuanVideoTransformerBlock(
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Single stream transformer blocks
+ if image_condition_type == "token_replace":
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ HunyuanVideoTokenReplaceSingleTransformerBlock(
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+ else:
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ HunyuanVideoSingleTransformerBlock(
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+
+ # 5. Output projection
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor,
+ pooled_projections: torch.Tensor,
+ guidance: torch.Tensor = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p, p_t = self.config.patch_size, self.config.patch_size_t
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p
+ post_patch_width = width // p
+ first_frame_num_tokens = 1 * post_patch_height * post_patch_width
+
+ # 1. RoPE
+ image_rotary_emb = self.rope(hidden_states)
+
+ # 2. Conditional embeddings
+ temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
+
+ hidden_states = self.x_embedder(hidden_states)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
+
+ # 3. Attention mask preparation
+ latent_sequence_length = hidden_states.shape[1]
+ condition_sequence_length = encoder_hidden_states.shape[1]
+ sequence_length = latent_sequence_length + condition_sequence_length
+ attention_mask = torch.zeros(
+ batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
+ ) # [B, N]
+
+ effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
+ effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
+
+ for i in range(batch_size):
+ attention_mask[i, : effective_sequence_length[i]] = True
+ # [B, 1, 1, N], for broadcasting across attention heads
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ image_rotary_emb,
+ token_replace_emb,
+ first_frame_num_tokens,
+ )
+
+ for block in self.single_transformer_blocks:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ image_rotary_emb,
+ token_replace_emb,
+ first_frame_num_tokens,
+ )
+
+ else:
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ image_rotary_emb,
+ token_replace_emb,
+ first_frame_num_tokens,
+ )
+
+ for block in self.single_transformer_blocks:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask,
+ image_rotary_emb,
+ token_replace_emb,
+ first_frame_num_tokens,
+ )
+
+ # 5. Output projection
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
+ )
+ hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+ hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return Transformer2DModelOutput(sample=hidden_states)
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
new file mode 100644
index 000000000000..2ae2418098f6
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -0,0 +1,487 @@
+# Copyright 2024 The Genmo team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class LTXVideoAttentionProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class LTXVideoRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ base_num_frames: int = 20,
+ base_height: int = 2048,
+ base_width: int = 2048,
+ patch_size: int = 1,
+ patch_size_t: int = 1,
+ theta: float = 10000.0,
+ ) -> None:
+ super().__init__()
+
+ self.dim = dim
+ self.base_num_frames = base_num_frames
+ self.base_height = base_height
+ self.base_width = base_width
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.theta = theta
+
+ def _prepare_video_coords(
+ self,
+ batch_size: int,
+ num_frames: int,
+ height: int,
+ width: int,
+ rope_interpolation_scale: Tuple[torch.Tensor, float, float],
+ device: torch.device,
+ ) -> torch.Tensor:
+ # Always compute rope in fp32
+ grid_h = torch.arange(height, dtype=torch.float32, device=device)
+ grid_w = torch.arange(width, dtype=torch.float32, device=device)
+ grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
+ grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
+ grid = torch.stack(grid, dim=0)
+ grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+
+ if rope_interpolation_scale is not None:
+ grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
+ grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
+ grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+
+ grid = grid.flatten(2, 4).transpose(1, 2)
+
+ return grid
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ num_frames: Optional[int] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
+ video_coords: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size = hidden_states.size(0)
+
+ if video_coords is None:
+ grid = self._prepare_video_coords(
+ batch_size,
+ num_frames,
+ height,
+ width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ device=hidden_states.device,
+ )
+ else:
+ grid = torch.stack(
+ [
+ video_coords[:, 0] / self.base_num_frames,
+ video_coords[:, 1] / self.base_height,
+ video_coords[:, 2] / self.base_width,
+ ],
+ dim=-1,
+ )
+
+ start = 1.0
+ end = self.theta
+ freqs = self.theta ** torch.linspace(
+ math.log(start, self.theta),
+ math.log(end, self.theta),
+ self.dim // 6,
+ device=hidden_states.device,
+ dtype=torch.float32,
+ )
+ freqs = freqs * math.pi / 2.0
+ freqs = freqs * (grid.unsqueeze(-1) * 2 - 1)
+ freqs = freqs.transpose(-1, -2).flatten(2)
+
+ cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
+ sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
+
+ if self.dim % 6 != 0:
+ cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % 6])
+ sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % 6])
+ cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
+ sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
+
+ return cos_freqs, sin_freqs
+
+
+@maybe_allow_in_graph
+class LTXVideoTransformerBlock(nn.Module):
+ r"""
+ Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
+
+ Args:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ qk_norm (`str`, defaults to `"rms_norm"`):
+ The normalization layer to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ cross_attention_dim: int,
+ qk_norm: str = "rms_norm_across_heads",
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = True,
+ attention_out_bias: bool = True,
+ eps: float = 1e-6,
+ elementwise_affine: bool = False,
+ ):
+ super().__init__()
+
+ self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ kv_heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ out_bias=attention_out_bias,
+ qk_norm=qk_norm,
+ processor=LTXVideoAttentionProcessor2_0(),
+ )
+
+ self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ kv_heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ qk_norm=qk_norm,
+ processor=LTXVideoAttentionProcessor2_0(),
+ )
+
+ self.ff = FeedForward(dim, activation_fn=activation_fn)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size = hidden_states.size(0)
+ norm_hidden_states = self.norm1(hidden_states)
+
+ num_ada_params = self.scale_shift_table.shape[0]
+ ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+
+ attn_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=None,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + attn_hidden_states * gate_msa
+
+ attn_hidden_states = self.attn2(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_rotary_emb=None,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = hidden_states + attn_hidden_states
+ norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
+
+ ff_output = self.ff(norm_hidden_states)
+ hidden_states = hidden_states + ff_output * gate_mlp
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+ r"""
+ A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
+
+ Args:
+ in_channels (`int`, defaults to `128`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `128`):
+ The number of channels in the output.
+ patch_size (`int`, defaults to `1`):
+ The size of the spatial patches to use in the patch embedding layer.
+ patch_size_t (`int`, defaults to `1`):
+ The size of the tmeporal patches to use in the patch embedding layer.
+ num_attention_heads (`int`, defaults to `32`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ cross_attention_dim (`int`, defaults to `2048 `):
+ The number of channels for cross attention heads.
+ num_layers (`int`, defaults to `28`):
+ The number of layers of Transformer blocks to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
+ The normalization layer to use.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 128,
+ out_channels: int = 128,
+ patch_size: int = 1,
+ patch_size_t: int = 1,
+ num_attention_heads: int = 32,
+ attention_head_dim: int = 64,
+ cross_attention_dim: int = 2048,
+ num_layers: int = 28,
+ activation_fn: str = "gelu-approximate",
+ qk_norm: str = "rms_norm_across_heads",
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ caption_channels: int = 4096,
+ attention_bias: bool = True,
+ attention_out_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
+
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+
+ self.rope = LTXVideoRotaryPosEmbed(
+ dim=inner_dim,
+ base_num_frames=20,
+ base_height=2048,
+ base_width=2048,
+ patch_size=patch_size,
+ patch_size_t=patch_size_t,
+ theta=10000.0,
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ LTXVideoTransformerBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ qk_norm=qk_norm,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ attention_out_bias=attention_out_bias,
+ eps=norm_eps,
+ elementwise_affine=norm_elementwise_affine,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_attention_mask: torch.Tensor,
+ num_frames: Optional[int] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
+ video_coords: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> torch.Tensor:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ batch_size = hidden_states.size(0)
+ hidden_states = self.proj_in(hidden_states)
+
+ temb, embedded_timestep = self.time_embed(
+ timestep.flatten(),
+ batch_size=batch_size,
+ hidden_dtype=hidden_states.dtype,
+ )
+
+ temb = temb.view(batch_size, -1, temb.size(-1))
+ embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
+
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
+
+ for block in self.transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ encoder_attention_mask,
+ )
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+ shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = hidden_states * (1 + scale) + shift
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
+
+
+def apply_rotary_emb(x, freqs):
+ cos, sin = freqs
+ x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2]
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+ return out
diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py
new file mode 100644
index 000000000000..a873a6ec9444
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_lumina2.py
@@ -0,0 +1,548 @@
+# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import LuminaFeedForward
+from ..attention_processor import Attention
+from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ cap_feat_dim: int = 2048,
+ frequency_embedding_size: int = 256,
+ norm_eps: float = 1e-5,
+ ) -> None:
+ super().__init__()
+
+ self.time_proj = Timesteps(
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
+ )
+
+ self.timestep_embedder = TimestepEmbedding(
+ in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
+ )
+
+ self.caption_embedder = nn.Sequential(
+ RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
+ )
+
+ def forward(
+ self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ timestep_proj = self.time_proj(timestep).type_as(hidden_states)
+ time_embed = self.timestep_embedder(timestep_proj)
+ caption_embed = self.caption_embedder(encoder_hidden_states)
+ return time_embed, caption_embed
+
+
+class Lumina2AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ base_sequence_length: Optional[int] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ # Get Query-Key-Value Pair
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query_dim = query.shape[-1]
+ inner_dim = key.shape[-1]
+ head_dim = query_dim // attn.heads
+ dtype = query.dtype
+
+ # Get key-value heads
+ kv_heads = inner_dim // head_dim
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, kv_heads, head_dim)
+ value = value.view(batch_size, -1, kv_heads, head_dim)
+
+ # Apply Query-Key Norm if needed
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
+
+ query, key = query.to(dtype), key.to(dtype)
+
+ # Apply proportional attention if true
+ if base_sequence_length is not None:
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+ else:
+ softmax_scale = attn.scale
+
+ # perform Grouped-qurey Attention (GQA)
+ n_rep = attn.heads // kv_heads
+ if n_rep >= 1:
+ key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+ value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ if attention_mask is not None:
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.type_as(query)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class Lumina2TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ num_kv_heads: int,
+ multiple_of: int,
+ ffn_dim_multiplier: float,
+ norm_eps: float,
+ modulation: bool = True,
+ ) -> None:
+ super().__init__()
+ self.head_dim = dim // num_attention_heads
+ self.modulation = modulation
+
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=dim // num_attention_heads,
+ qk_norm="rms_norm",
+ heads=num_attention_heads,
+ kv_heads=num_kv_heads,
+ eps=1e-5,
+ bias=False,
+ out_bias=False,
+ processor=Lumina2AttnProcessor2_0(),
+ )
+
+ self.feed_forward = LuminaFeedForward(
+ dim=dim,
+ inner_dim=4 * dim,
+ multiple_of=multiple_of,
+ ffn_dim_multiplier=ffn_dim_multiplier,
+ )
+
+ if modulation:
+ self.norm1 = LuminaRMSNormZero(
+ embedding_dim=dim,
+ norm_eps=norm_eps,
+ norm_elementwise_affine=True,
+ )
+ else:
+ self.norm1 = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+
+ self.norm2 = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ image_rotary_emb: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if self.modulation:
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + self.norm2(attn_output)
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
+
+ return hidden_states
+
+
+class Lumina2RotaryPosEmbed(nn.Module):
+ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+ self.axes_lens = axes_lens
+ self.patch_size = patch_size
+
+ self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)
+
+ def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
+ freqs_cis = []
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+ for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
+ emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
+ freqs_cis.append(emb)
+ return freqs_cis
+
+ def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
+ device = ids.device
+ if ids.device.type == "mps":
+ ids = ids.to("cpu")
+
+ result = []
+ for i in range(len(self.axes_dim)):
+ freqs = self.freqs_cis[i].to(ids.device)
+ index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
+ result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
+ return torch.cat(result, dim=-1).to(device)
+
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
+ batch_size, channels, height, width = hidden_states.shape
+ p = self.patch_size
+ post_patch_height, post_patch_width = height // p, width // p
+ image_seq_len = post_patch_height * post_patch_width
+ device = hidden_states.device
+
+ encoder_seq_len = attention_mask.shape[1]
+ l_effective_cap_len = attention_mask.sum(dim=1).tolist()
+ seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
+ max_seq_len = max(seq_lengths)
+
+ # Create position IDs
+ position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
+
+ for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+ # add caption position ids
+ position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
+ position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
+
+ # add image position ids
+ row_ids = (
+ torch.arange(post_patch_height, dtype=torch.int32, device=device)
+ .view(-1, 1)
+ .repeat(1, post_patch_width)
+ .flatten()
+ )
+ col_ids = (
+ torch.arange(post_patch_width, dtype=torch.int32, device=device)
+ .view(1, -1)
+ .repeat(post_patch_height, 1)
+ .flatten()
+ )
+ position_ids[i, cap_seq_len:seq_len, 1] = row_ids
+ position_ids[i, cap_seq_len:seq_len, 2] = col_ids
+
+ # Get combined rotary embeddings
+ freqs_cis = self._get_freqs_cis(position_ids)
+
+ # create separate rotary embeddings for captions and images
+ cap_freqs_cis = torch.zeros(
+ batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+ )
+ img_freqs_cis = torch.zeros(
+ batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+ )
+
+ for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+ cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
+ img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
+
+ # image patch embeddings
+ hidden_states = (
+ hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
+ .permute(0, 2, 4, 3, 5, 1)
+ .flatten(3)
+ .flatten(1, 2)
+ )
+
+ return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
+
+
+class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ r"""
+ Lumina2NextDiT: Diffusion model with a Transformer backbone.
+
+ Parameters:
+ sample_size (`int`): The width of the latent images. This is fixed during training since
+ it is used to learn a number of position embeddings.
+ patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2):
+ The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
+ in_channels (`int`, *optional*, defaults to 4):
+ The number of input channels for the model. Typically, this matches the number of channels in the input
+ images.
+ hidden_size (`int`, *optional*, defaults to 4096):
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+ hidden representations.
+ num_layers (`int`, *optional*, default to 32):
+ The number of layers in the model. This defines the depth of the neural network.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ The number of attention heads in each attention layer. This parameter specifies how many separate attention
+ mechanisms are used.
+ num_kv_heads (`int`, *optional*, defaults to 8):
+ The number of key-value heads in the attention mechanism, if different from the number of attention heads.
+ If None, it defaults to num_attention_heads.
+ multiple_of (`int`, *optional*, defaults to 256):
+ A factor that the hidden size should be a multiple of. This can help optimize certain hardware
+ configurations.
+ ffn_dim_multiplier (`float`, *optional*):
+ A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
+ the model configuration.
+ norm_eps (`float`, *optional*, defaults to 1e-5):
+ A small value added to the denominator for numerical stability in normalization layers.
+ scaling_factor (`float`, *optional*, defaults to 1.0):
+ A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
+ overall scale of the model's operations.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["Lumina2TransformerBlock"]
+ _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 128,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ out_channels: Optional[int] = None,
+ hidden_size: int = 2304,
+ num_layers: int = 26,
+ num_refiner_layers: int = 2,
+ num_attention_heads: int = 24,
+ num_kv_heads: int = 8,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: Optional[float] = None,
+ norm_eps: float = 1e-5,
+ scaling_factor: float = 1.0,
+ axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
+ cap_feat_dim: int = 1024,
+ ) -> None:
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+
+ # 1. Positional, patch & conditional embeddings
+ self.rope_embedder = Lumina2RotaryPosEmbed(
+ theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
+ )
+
+ self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)
+
+ self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
+ hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
+ )
+
+ # 2. Noise and context refinement blocks
+ self.noise_refiner = nn.ModuleList(
+ [
+ Lumina2TransformerBlock(
+ hidden_size,
+ num_attention_heads,
+ num_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ modulation=True,
+ )
+ for _ in range(num_refiner_layers)
+ ]
+ )
+
+ self.context_refiner = nn.ModuleList(
+ [
+ Lumina2TransformerBlock(
+ hidden_size,
+ num_attention_heads,
+ num_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ modulation=False,
+ )
+ for _ in range(num_refiner_layers)
+ ]
+ )
+
+ # 3. Transformer blocks
+ self.layers = nn.ModuleList(
+ [
+ Lumina2TransformerBlock(
+ hidden_size,
+ num_attention_heads,
+ num_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ modulation=True,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = LuminaLayerNormContinuous(
+ embedding_dim=hidden_size,
+ conditioning_embedding_dim=min(hidden_size, 1024),
+ elementwise_affine=False,
+ eps=1e-6,
+ bias=True,
+ out_dim=patch_size * patch_size * self.out_channels,
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # 1. Condition, positional & patch embedding
+ batch_size, _, height, width = hidden_states.shape
+
+ temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
+
+ (
+ hidden_states,
+ context_rotary_emb,
+ noise_rotary_emb,
+ rotary_emb,
+ encoder_seq_lengths,
+ seq_lengths,
+ ) = self.rope_embedder(hidden_states, encoder_attention_mask)
+
+ hidden_states = self.x_embedder(hidden_states)
+
+ # 2. Context & noise refinement
+ for layer in self.context_refiner:
+ encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
+
+ for layer in self.noise_refiner:
+ hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)
+
+ # 3. Joint Transformer blocks
+ max_seq_len = max(seq_lengths)
+ use_mask = len(set(seq_lengths)) > 1
+
+ attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
+ joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+ attention_mask[i, :seq_len] = True
+ joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
+ joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]
+
+ hidden_states = joint_hidden_states
+
+ for layer in self.layers:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
+ )
+ else:
+ hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
+
+ # 4. Output norm & projection
+ hidden_states = self.norm_out(hidden_states, temb)
+
+ # 5. Unpatchify
+ p = self.config.patch_size
+ output = []
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+ output.append(
+ hidden_states[i][encoder_seq_len:seq_len]
+ .view(height // p, width // p, p, p, self.out_channels)
+ .permute(4, 0, 2, 1, 3)
+ .flatten(3, 4)
+ .flatten(1, 2)
+ )
+ output = torch.stack(output, dim=0)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
new file mode 100644
index 000000000000..e6532f080d72
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -0,0 +1,488 @@
+# Copyright 2024 The Genmo team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
+from ..cache_utils import CacheMixin
+from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class MochiModulatedRMSNorm(nn.Module):
+ def __init__(self, eps: float):
+ super().__init__()
+
+ self.eps = eps
+ self.norm = RMSNorm(0, eps, False)
+
+ def forward(self, hidden_states, scale=None):
+ hidden_states_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+
+ hidden_states = self.norm(hidden_states)
+
+ if scale is not None:
+ hidden_states = hidden_states * scale
+
+ hidden_states = hidden_states.to(hidden_states_dtype)
+
+ return hidden_states
+
+
+class MochiLayerNormContinuous(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ eps=1e-5,
+ bias=True,
+ ):
+ super().__init__()
+
+ # AdaLN
+ self.silu = nn.SiLU()
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+ self.norm = MochiModulatedRMSNorm(eps=eps)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ conditioning_embedding: torch.Tensor,
+ ) -> torch.Tensor:
+ input_dtype = x.dtype
+
+ # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+ scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
+ x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))
+
+ return x.to(input_dtype)
+
+
+class MochiRMSNormZero(nn.Module):
+ r"""
+ Adaptive RMS Norm used in Mochi.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ """
+
+ def __init__(
+ self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
+ ) -> None:
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, hidden_dim)
+ self.norm = RMSNorm(0, eps, False)
+
+ def forward(
+ self, hidden_states: torch.Tensor, emb: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ hidden_states_dtype = hidden_states.dtype
+
+ emb = self.linear(self.silu(emb))
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
+ hidden_states = self.norm(hidden_states.to(torch.float32)) * (1 + scale_msa[:, None].to(torch.float32))
+ hidden_states = hidden_states.to(hidden_states_dtype)
+
+ return hidden_states, gate_msa, scale_mlp, gate_mlp
+
+
+@maybe_allow_in_graph
+class MochiTransformerBlock(nn.Module):
+ r"""
+ Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
+
+ Args:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ qk_norm (`str`, defaults to `"rms_norm"`):
+ The normalization layer to use.
+ activation_fn (`str`, defaults to `"swiglu"`):
+ Activation function to use in feed-forward.
+ context_pre_only (`bool`, defaults to `False`):
+ Whether or not to process context-related conditions with additional layers.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ pooled_projection_dim: int,
+ qk_norm: str = "rms_norm",
+ activation_fn: str = "swiglu",
+ context_pre_only: bool = False,
+ eps: float = 1e-6,
+ ) -> None:
+ super().__init__()
+
+ self.context_pre_only = context_pre_only
+ self.ff_inner_dim = (4 * dim * 2) // 3
+ self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3
+
+ self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False)
+
+ if not context_pre_only:
+ self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
+ else:
+ self.norm1_context = MochiLayerNormContinuous(
+ embedding_dim=pooled_projection_dim,
+ conditioning_embedding_dim=dim,
+ eps=eps,
+ )
+
+ self.attn1 = MochiAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=False,
+ added_kv_proj_dim=pooled_projection_dim,
+ added_proj_bias=False,
+ out_dim=dim,
+ out_context_dim=pooled_projection_dim,
+ context_pre_only=context_pre_only,
+ processor=MochiAttnProcessor2_0(),
+ eps=1e-5,
+ )
+
+ # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
+ self.norm2 = MochiModulatedRMSNorm(eps=eps)
+ self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
+
+ self.norm3 = MochiModulatedRMSNorm(eps)
+ self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
+
+ self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
+ self.ff_context = None
+ if not context_pre_only:
+ self.ff_context = FeedForward(
+ pooled_projection_dim,
+ inner_dim=self.ff_context_inner_dim,
+ activation_fn=activation_fn,
+ bias=False,
+ )
+
+ self.norm4 = MochiModulatedRMSNorm(eps=eps)
+ self.norm4_context = MochiModulatedRMSNorm(eps=eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ encoder_attention_mask: torch.Tensor,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+
+ if not self.context_pre_only:
+ norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
+ encoder_hidden_states, temb
+ )
+ else:
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+
+ attn_hidden_states, context_attn_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=encoder_attention_mask,
+ )
+
+ hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
+ norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
+ ff_output = self.ff(norm_hidden_states)
+ hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
+
+ if not self.context_pre_only:
+ encoder_hidden_states = encoder_hidden_states + self.norm2_context(
+ context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
+ )
+ norm_encoder_hidden_states = self.norm3_context(
+ encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
+ )
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + self.norm4_context(
+ context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
+ )
+
+ return hidden_states, encoder_hidden_states
+
+
+class MochiRoPE(nn.Module):
+ r"""
+ RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
+
+ Args:
+ base_height (`int`, defaults to `192`):
+ Base height used to compute interpolation scale for rotary positional embeddings.
+ base_width (`int`, defaults to `192`):
+ Base width used to compute interpolation scale for rotary positional embeddings.
+ """
+
+ def __init__(self, base_height: int = 192, base_width: int = 192) -> None:
+ super().__init__()
+
+ self.target_area = base_height * base_width
+
+ def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
+ edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
+ return (edges[:-1] + edges[1:]) / 2
+
+ def _get_positions(
+ self,
+ num_frames: int,
+ height: int,
+ width: int,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
+ scale = (self.target_area / (height * width)) ** 0.5
+
+ t = torch.arange(num_frames, device=device, dtype=dtype)
+ h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
+ w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
+
+ grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
+
+ positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)
+ return positions
+
+ def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
+ with torch.autocast(freqs.device.type, torch.float32):
+ # Always run ROPE freqs computation in FP32
+ freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32))
+
+ freqs_cos = torch.cos(freqs)
+ freqs_sin = torch.sin(freqs)
+ return freqs_cos, freqs_sin
+
+ def forward(
+ self,
+ pos_frequencies: torch.Tensor,
+ num_frames: int,
+ height: int,
+ width: int,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ pos = self._get_positions(num_frames, height, width, device, dtype)
+ rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
+ return rope_cos, rope_sin
+
+
+@maybe_allow_in_graph
+class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ r"""
+ A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
+
+ Args:
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `48`):
+ The number of layers of Transformer blocks to use.
+ in_channels (`int`, defaults to `12`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `None`):
+ The number of channels in the output.
+ qk_norm (`str`, defaults to `"rms_norm"`):
+ The normalization layer to use.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ time_embed_dim (`int`, defaults to `256`):
+ Output dimension of timestep embeddings.
+ activation_fn (`str`, defaults to `"swiglu"`):
+ Activation function to use in feed-forward.
+ max_sequence_length (`int`, defaults to `256`):
+ The maximum sequence length of text embeddings supported.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["MochiTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ num_attention_heads: int = 24,
+ attention_head_dim: int = 128,
+ num_layers: int = 48,
+ pooled_projection_dim: int = 1536,
+ in_channels: int = 12,
+ out_channels: Optional[int] = None,
+ qk_norm: str = "rms_norm",
+ text_embed_dim: int = 4096,
+ time_embed_dim: int = 256,
+ activation_fn: str = "swiglu",
+ max_sequence_length: int = 256,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ pos_embed_type=None,
+ )
+
+ self.time_embed = MochiCombinedTimestepCaptionEmbedding(
+ embedding_dim=inner_dim,
+ pooled_projection_dim=pooled_projection_dim,
+ text_embed_dim=text_embed_dim,
+ time_embed_dim=time_embed_dim,
+ num_attention_heads=8,
+ )
+
+ self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0))
+ self.rope = MochiRoPE()
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ MochiTransformerBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ pooled_projection_dim=pooled_projection_dim,
+ qk_norm=qk_norm,
+ activation_fn=activation_fn,
+ context_pre_only=i == num_layers - 1,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(
+ inner_dim,
+ inner_dim,
+ elementwise_affine=False,
+ eps=1e-6,
+ norm_type="layer_norm",
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_attention_mask: torch.Tensor,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> torch.Tensor:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p = self.config.patch_size
+
+ post_patch_height = height // p
+ post_patch_width = width // p
+
+ temb, encoder_hidden_states = self.time_embed(
+ timestep,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ hidden_dtype=hidden_states.dtype,
+ )
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ hidden_states = self.patch_embed(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
+
+ image_rotary_emb = self.rope(
+ self.pos_frequencies,
+ num_frames,
+ post_patch_height,
+ post_patch_width,
+ device=hidden_states.device,
+ dtype=torch.float32,
+ )
+
+ for i, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ encoder_attention_mask,
+ image_rotary_emb,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ encoder_attention_mask=encoder_attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
+ hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
+ output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
new file mode 100644
index 000000000000..8d5d1b3f8fea
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -0,0 +1,469 @@
+# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ..attention_processor import Attention
+from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNorm, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class OmniGenFeedForward(nn.Module):
+ def __init__(self, hidden_size: int, intermediate_size: int):
+ super().__init__()
+
+ self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+ self.activation_fn = nn.SiLU()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ up_states = self.gate_up_proj(hidden_states)
+ gate, up_states = up_states.chunk(2, dim=-1)
+ up_states = up_states * self.activation_fn(gate)
+ return self.down_proj(up_states)
+
+
+class OmniGenPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 4,
+ embed_dim: int = 768,
+ bias: bool = True,
+ interpolation_scale: float = 1,
+ pos_embed_max_size: int = 192,
+ base_size: int = 64,
+ ):
+ super().__init__()
+
+ self.output_image_proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ self.input_image_proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+
+ self.patch_size = patch_size
+ self.interpolation_scale = interpolation_scale
+ self.pos_embed_max_size = pos_embed_max_size
+
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim,
+ self.pos_embed_max_size,
+ base_size=base_size,
+ interpolation_scale=self.interpolation_scale,
+ output_type="pt",
+ )
+ self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True)
+
+ def _cropped_pos_embed(self, height, width):
+ """Crops positional embeddings for SD3 compatibility."""
+ if self.pos_embed_max_size is None:
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+ height = height // self.patch_size
+ width = width // self.patch_size
+ if height > self.pos_embed_max_size:
+ raise ValueError(
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+ if width > self.pos_embed_max_size:
+ raise ValueError(
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+
+ top = (self.pos_embed_max_size - height) // 2
+ left = (self.pos_embed_max_size - width) // 2
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+ return spatial_pos_embed
+
+ def _patch_embeddings(self, hidden_states: torch.Tensor, is_input_image: bool) -> torch.Tensor:
+ if is_input_image:
+ hidden_states = self.input_image_proj(hidden_states)
+ else:
+ hidden_states = self.output_image_proj(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+ return hidden_states
+
+ def forward(
+ self, hidden_states: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None
+ ) -> torch.Tensor:
+ if isinstance(hidden_states, list):
+ if padding_latent is None:
+ padding_latent = [None] * len(hidden_states)
+ patched_latents = []
+ for sub_latent, padding in zip(hidden_states, padding_latent):
+ height, width = sub_latent.shape[-2:]
+ sub_latent = self._patch_embeddings(sub_latent, is_input_image)
+ pos_embed = self._cropped_pos_embed(height, width)
+ sub_latent = sub_latent + pos_embed
+ if padding is not None:
+ sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2)
+ patched_latents.append(sub_latent)
+ else:
+ height, width = hidden_states.shape[-2:]
+ pos_embed = self._cropped_pos_embed(height, width)
+ hidden_states = self._patch_embeddings(hidden_states, is_input_image)
+ patched_latents = hidden_states + pos_embed
+
+ return patched_latents
+
+
+class OmniGenSuScaledRotaryEmbedding(nn.Module):
+ def __init__(
+ self, dim, max_position_embeddings=131072, original_max_position_embeddings=4096, base=10000, rope_scaling=None
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+ self.short_factor = rope_scaling["short_factor"]
+ self.long_factor = rope_scaling["long_factor"]
+ self.original_max_position_embeddings = original_max_position_embeddings
+
+ def forward(self, hidden_states, position_ids):
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.original_max_position_embeddings:
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=hidden_states.device)
+ else:
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=hidden_states.device)
+
+ inv_freq_shape = (
+ torch.arange(0, self.dim, 2, dtype=torch.int64, device=hidden_states.device).float() / self.dim
+ )
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = hidden_states.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)[0]
+
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
+ if scale <= 1.0:
+ scaling_factor = 1.0
+ else:
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
+ cos = emb.cos() * scaling_factor
+ sin = emb.sin() * scaling_factor
+ return cos, sin
+
+
+class OmniGenAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the OmniGen model.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ # Get Query-Key-Value Pair
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ bsz, q_len, query_dim = query.size()
+ inner_dim = key.shape[-1]
+ head_dim = query_dim // attn.heads
+
+ # Get key-value heads
+ kv_heads = inner_dim // head_dim
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb, use_real_unbind_dim=-2)
+ key = apply_rotary_emb(key, image_rotary_emb, use_real_unbind_dim=-2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
+ hidden_states = hidden_states.transpose(1, 2).type_as(query)
+ hidden_states = hidden_states.reshape(bsz, q_len, attn.out_dim)
+ hidden_states = attn.to_out[0](hidden_states)
+ return hidden_states
+
+
+class OmniGenBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_attention_heads: int,
+ num_key_value_heads: int,
+ intermediate_size: int,
+ rms_norm_eps: float,
+ ) -> None:
+ super().__init__()
+
+ self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
+ self.self_attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=hidden_size,
+ dim_head=hidden_size // num_attention_heads,
+ heads=num_attention_heads,
+ kv_heads=num_key_value_heads,
+ bias=False,
+ out_dim=hidden_size,
+ out_bias=False,
+ processor=OmniGenAttnProcessor2_0(),
+ )
+ self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
+ self.mlp = OmniGenFeedForward(hidden_size, intermediate_size)
+
+ def forward(
+ self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor
+ ) -> torch.Tensor:
+ # 1. Attention
+ norm_hidden_states = self.input_layernorm(hidden_states)
+ attn_output = self.self_attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + attn_output
+
+ # 2. Feed Forward
+ norm_hidden_states = self.post_attention_layernorm(hidden_states)
+ ff_output = self.mlp(norm_hidden_states)
+ hidden_states = hidden_states + ff_output
+ return hidden_states
+
+
+class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
+ """
+ The Transformer model introduced in OmniGen (https://arxiv.org/pdf/2409.11340).
+
+ Parameters:
+ in_channels (`int`, defaults to `4`):
+ The number of channels in the input.
+ patch_size (`int`, defaults to `2`):
+ The size of the spatial patches to use in the patch embedding layer.
+ hidden_size (`int`, defaults to `3072`):
+ The dimensionality of the hidden layers in the model.
+ rms_norm_eps (`float`, defaults to `1e-5`):
+ Eps for RMSNorm layer.
+ num_attention_heads (`int`, defaults to `32`):
+ The number of heads to use for multi-head attention.
+ num_key_value_heads (`int`, defaults to `32`):
+ The number of heads to use for keys and values in multi-head attention.
+ intermediate_size (`int`, defaults to `8192`):
+ Dimension of the hidden layer in FeedForward layers.
+ num_layers (`int`, default to `32`):
+ The number of layers of transformer blocks to use.
+ pad_token_id (`int`, default to `32000`):
+ The id of the padding token.
+ vocab_size (`int`, default to `32064`):
+ The size of the vocabulary of the embedding vocabulary.
+ rope_base (`int`, default to `10000`):
+ The default theta value to use when creating RoPE.
+ rope_scaling (`Dict`, optional):
+ The scaling factors for the RoPE. Must contain `short_factor` and `long_factor`.
+ pos_embed_max_size (`int`, default to `192`):
+ The maximum size of the positional embeddings.
+ time_step_dim (`int`, default to `256`):
+ Output dimension of timestep embeddings.
+ flip_sin_to_cos (`bool`, default to `True`):
+ Whether to flip the sin and cos in the positional embeddings when preparing timestep embeddings.
+ downscale_freq_shift (`int`, default to `0`):
+ The frequency shift to use when downscaling the timestep embeddings.
+ timestep_activation_fn (`str`, default to `silu`):
+ The activation function to use for the timestep embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["OmniGenBlock"]
+ _skip_layerwise_casting_patterns = ["patch_embedding", "embed_tokens", "norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ patch_size: int = 2,
+ hidden_size: int = 3072,
+ rms_norm_eps: float = 1e-5,
+ num_attention_heads: int = 32,
+ num_key_value_heads: int = 32,
+ intermediate_size: int = 8192,
+ num_layers: int = 32,
+ pad_token_id: int = 32000,
+ vocab_size: int = 32064,
+ max_position_embeddings: int = 131072,
+ original_max_position_embeddings: int = 4096,
+ rope_base: int = 10000,
+ rope_scaling: Dict = None,
+ pos_embed_max_size: int = 192,
+ time_step_dim: int = 256,
+ flip_sin_to_cos: bool = True,
+ downscale_freq_shift: int = 0,
+ timestep_activation_fn: str = "silu",
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+
+ self.patch_embedding = OmniGenPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=hidden_size,
+ pos_embed_max_size=pos_embed_max_size,
+ )
+
+ self.time_proj = Timesteps(time_step_dim, flip_sin_to_cos, downscale_freq_shift)
+ self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
+ self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
+
+ self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
+ self.rope = OmniGenSuScaledRotaryEmbedding(
+ hidden_size // num_attention_heads,
+ max_position_embeddings=max_position_embeddings,
+ original_max_position_embeddings=original_max_position_embeddings,
+ base=rope_base,
+ rope_scaling=rope_scaling,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ OmniGenBlock(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, rms_norm_eps)
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
+ self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
+ self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ def _get_multimodal_embeddings(
+ self, input_ids: torch.Tensor, input_img_latents: List[torch.Tensor], input_image_sizes: Dict
+ ) -> Optional[torch.Tensor]:
+ if input_ids is None:
+ return None
+
+ input_img_latents = [x.to(self.dtype) for x in input_img_latents]
+ condition_tokens = self.embed_tokens(input_ids)
+ input_img_inx = 0
+ input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True)
+ for b_inx in input_image_sizes.keys():
+ for start_inx, end_inx in input_image_sizes[b_inx]:
+ # replace the placeholder in text tokens with the image embedding.
+ condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to(
+ condition_tokens.dtype
+ )
+ input_img_inx += 1
+ return condition_tokens
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.FloatTensor],
+ input_ids: torch.Tensor,
+ input_img_latents: List[torch.Tensor],
+ input_image_sizes: Dict[int, List[int]],
+ attention_mask: torch.Tensor,
+ position_ids: torch.Tensor,
+ return_dict: bool = True,
+ ) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]:
+ batch_size, num_channels, height, width = hidden_states.shape
+ p = self.config.patch_size
+ post_patch_height, post_patch_width = height // p, width // p
+
+ # 1. Patch & Timestep & Conditional Embedding
+ hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
+ num_tokens_for_output_image = hidden_states.size(1)
+
+ timestep_proj = self.time_proj(timestep).type_as(hidden_states)
+ time_token = self.time_token(timestep_proj).unsqueeze(1)
+ temb = self.t_embedder(timestep_proj)
+
+ condition_tokens = self._get_multimodal_embeddings(input_ids, input_img_latents, input_image_sizes)
+ if condition_tokens is not None:
+ hidden_states = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
+ else:
+ hidden_states = torch.cat([time_token, hidden_states], dim=1)
+
+ seq_length = hidden_states.size(1)
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # 2. Attention mask preprocessing
+ if attention_mask is not None and attention_mask.dim() == 3:
+ dtype = hidden_states.dtype
+ min_dtype = torch.finfo(dtype).min
+ attention_mask = (1 - attention_mask) * min_dtype
+ attention_mask = attention_mask.unsqueeze(1).type_as(hidden_states)
+
+ # 3. Rotary position embedding
+ image_rotary_emb = self.rope(hidden_states, position_ids)
+
+ # 4. Transformer blocks
+ for block in self.layers:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, attention_mask, image_rotary_emb
+ )
+ else:
+ hidden_states = block(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb)
+
+ # 5. Output norm & projection
+ hidden_states = self.norm(hidden_states)
+ hidden_states = hidden_states[:, -num_tokens_for_output_image:]
+ hidden_states = self.norm_out(hidden_states, temb=temb)
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, p, p, -1)
+ output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index b28350b8ed9c..e41fad220de6 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -11,20 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
+from ...models.attention import FeedForward, JointTransformerBlock
+from ...models.attention_processor import (
+ Attention,
+ AttentionProcessor,
+ FusedJointAttnProcessor2_0,
+ JointAttnProcessor2_0,
+)
from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -32,28 +36,86 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
- """
- The Transformer model introduced in Stable Diffusion 3.
+@maybe_allow_in_graph
+class SD3SingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ ):
+ super().__init__()
- Reference: https://arxiv.org/abs/2403.03206
+ self.norm1 = AdaLayerNormZero(dim)
+ self.attn = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=JointAttnProcessor2_0(),
+ eps=1e-6,
+ )
- Parameters:
- sample_size (`int`): The width of the latent images. This is fixed during training since
- it is used to learn a number of position embeddings.
- patch_size (`int`): Patch size to turn the input data into small patches.
- in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
- num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
- attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
- num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
- caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
- pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
- out_channels (`int`, defaults to 16): Number of output channels.
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+ # 1. Attention
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+ attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+ # 2. Feed Forward
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ hidden_states = hidden_states + ff_output
+
+ return hidden_states
+
+
+class SD3Transformer2DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
+):
+ """
+ The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
+
+ Parameters:
+ sample_size (`int`, defaults to `128`):
+ The width/height of the latents. This is fixed during training since it is used to learn a number of
+ position embeddings.
+ patch_size (`int`, defaults to `2`):
+ Patch size to turn the input data into small patches.
+ in_channels (`int`, defaults to `16`):
+ The number of latent channels in the input.
+ num_layers (`int`, defaults to `18`):
+ The number of layers of transformer blocks to use.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ num_attention_heads (`int`, defaults to `18`):
+ The number of heads to use for multi-head attention.
+ joint_attention_dim (`int`, defaults to `4096`):
+ The embedding dimension to use for joint text-image attention.
+ caption_projection_dim (`int`, defaults to `1152`):
+ The embedding dimension of caption embeddings.
+ pooled_projection_dim (`int`, defaults to `2048`):
+ The embedding dimension of pooled text projections.
+ out_channels (`int`, defaults to `16`):
+ The number of latent channels in the output.
+ pos_embed_max_size (`int`, defaults to `96`):
+ The maximum latent height/width of positional embeddings.
+ dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+ The number of dual-stream transformer blocks to use.
+ qk_norm (`str`, *optional*, defaults to `None`):
+ The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
"""
_supports_gradient_checkpointing = True
+ _no_split_modules = ["JointTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config
def __init__(
@@ -75,36 +137,33 @@ def __init__(
qk_norm: Optional[str] = None,
):
super().__init__()
- default_out_channels = in_channels
- self.out_channels = out_channels if out_channels is not None else default_out_channels
- self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+ self.out_channels = out_channels if out_channels is not None else in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = PatchEmbed(
- height=self.config.sample_size,
- width=self.config.sample_size,
- patch_size=self.config.patch_size,
- in_channels=self.config.in_channels,
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size, # hard-code for now.
)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
- embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
- self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
+ self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
- # `attention_head_dim` is doubled to account for the mixing.
- # It needs to crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim,
- num_attention_heads=self.config.num_attention_heads,
- attention_head_dim=self.config.attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
context_pre_only=i == num_layers - 1,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
)
- for i in range(self.config.num_layers)
+ for i in range(num_layers)
]
)
@@ -255,33 +314,30 @@ def unfuse_qkv_projections(self):
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- pooled_projections: torch.FloatTensor = None,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
block_controlnet_hidden_states: List = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ skip_layers: Optional[List[int]] = None,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.
Args:
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
- from the embeddings of input conditions.
- timestep ( `torch.LongTensor`):
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
+ Embeddings projected from the embeddings of input conditions.
+ timestep (`torch.LongTensor`):
Used to indicate denoising step.
- block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ block_controlnet_hidden_states (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -290,6 +346,8 @@ def forward(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
+ skip_layers (`list` of `int`, *optional*):
+ A list of layer indices to skip during the forward pass.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -316,36 +374,36 @@ def forward(
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
- for index_block, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
+ joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
- return custom_forward
+ for index_block, block in enumerate(self.transformer_blocks):
+ # Skip specified layers
+ is_skip = True if skip_layers is not None and index_block in skip_layers else False
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
hidden_states,
encoder_hidden_states,
temb,
- **ckpt_kwargs,
+ joint_attention_kwargs,
)
-
- else:
+ elif not is_skip:
encoder_hidden_states, hidden_states = block(
- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if block_controlnet_hidden_states is not None and block.context_pre_only is False:
- interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
- hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+ interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
+ hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py
index c0c5467050dd..5580d0f70f9f 100644
--- a/src/diffusers/models/transformers/transformer_temporal.py
+++ b/src/diffusers/models/transformers/transformer_temporal.py
@@ -67,6 +67,8 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
The maximum length of the sequence over which to apply positional embeddings.
"""
+ _skip_layerwise_casting_patterns = ["norm"]
+
@register_to_config
def __init__(
self,
@@ -340,20 +342,12 @@ def forward(
# 2. Blocks
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
- if self.training and self.gradient_checkpointing:
- hidden_states = torch.utils.checkpoint.checkpoint(
- block,
- hidden_states,
- None,
- encoder_hidden_states,
- None,
- use_reentrant=False,
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, None, encoder_hidden_states, None
)
else:
- hidden_states = block(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- )
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states_mix = hidden_states
hidden_states_mix = hidden_states_mix + emb
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
new file mode 100644
index 000000000000..aa03e97093aa
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -0,0 +1,469 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class WanAttnProcessor2_0:
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ encoder_hidden_states_img = encoder_hidden_states[:, :257]
+ encoder_hidden_states = encoder_hidden_states[:, 257:]
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+ x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+ x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+ return x_out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, rotary_emb)
+ key = apply_rotary_emb(key, rotary_emb)
+
+ # I2V task
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ hidden_states_img = F.scaled_dot_product_attention(
+ query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+ hidden_states_img = hidden_states_img.type_as(query)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class WanImageEmbedding(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int):
+ super().__init__()
+
+ self.norm1 = FP32LayerNorm(in_features)
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+ self.norm2 = FP32LayerNorm(out_features)
+
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.norm1(encoder_hidden_states_image)
+ hidden_states = self.ff(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+class WanTimeTextImageEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ time_proj_dim: int,
+ text_embed_dim: int,
+ image_embed_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+ self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ self.image_embedder = None
+ if image_embed_dim is not None:
+ self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ ):
+ timestep = self.timesteps_proj(timestep)
+
+ time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ timestep_proj = self.time_proj(self.act_fn(temb))
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+class WanRotaryPosEmbed(nn.Module):
+ def __init__(
+ self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+
+ freqs = []
+ for dim in [t_dim, h_dim, w_dim]:
+ freq = get_1d_rotary_pos_embed(
+ dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
+ )
+ freqs.append(freq)
+ self.freqs = torch.cat(freqs, dim=1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ self.freqs = self.freqs.to(hidden_states.device)
+ freqs = self.freqs.split_with_sizes(
+ [
+ self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+ self.attention_head_dim // 6,
+ self.attention_head_dim // 6,
+ ],
+ dim=1,
+ )
+
+ freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+ freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+ return freqs
+
+
+class WanTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_heads,
+ kv_heads=num_heads,
+ dim_head=dim // num_heads,
+ qk_norm=qk_norm,
+ eps=eps,
+ bias=True,
+ cross_attention_dim=None,
+ out_bias=True,
+ processor=WanAttnProcessor2_0(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = Attention(
+ query_dim=dim,
+ heads=num_heads,
+ kv_heads=num_heads,
+ dim_head=dim // num_heads,
+ qk_norm=qk_norm,
+ eps=eps,
+ bias=True,
+ cross_attention_dim=None,
+ out_bias=True,
+ added_kv_proj_dim=added_kv_proj_dim,
+ added_proj_bias=True,
+ processor=WanAttnProcessor2_0(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ ) -> torch.Tensor:
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+ return hidden_states
+
+
+class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ r"""
+ A Transformer model for video-like data used in the Wan model.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+ num_attention_heads (`int`, defaults to `40`):
+ Fixed length for text embeddings.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ text_dim (`int`, defaults to `512`):
+ Input dimension for text embeddings.
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `13824`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `40`):
+ The number of layers of transformer blocks to use.
+ window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+ Window size for local attention (-1 indicates global attention).
+ cross_attn_norm (`bool`, defaults to `True`):
+ Enable cross-attention normalization.
+ qk_norm (`bool`, defaults to `True`):
+ Enable query/key normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ add_img_emb (`bool`, defaults to `False`):
+ Whether to use img_emb.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["WanTransformerBlock"]
+ _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_attention_heads: int = 40,
+ attention_head_dim: int = 128,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 13824,
+ num_layers: int = 40,
+ cross_attn_norm: bool = True,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ image_dim: Optional[int] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Patch & position embedding
+ self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Condition embeddings
+ # image_embedding_dim=1280 for I2V model
+ self.condition_embedder = WanTimeTextImageEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ image_embed_dim=image_dim,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ WanTransformerBlock(
+ inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image
+ )
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+ )
+ else:
+ for block in self.blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+
+ # 5. Output norm, projection & unpatchify
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py
index 8efabd98ee7d..ce496fd6baf8 100644
--- a/src/diffusers/models/unets/unet_1d.py
+++ b/src/diffusers/models/unets/unet_1d.py
@@ -71,6 +71,8 @@ class UNet1DModel(ModelMixin, ConfigMixin):
Experimental feature for using a UNet without upsampling.
"""
+ _skip_layerwise_casting_patterns = ["norm"]
+
@register_to_config
def __init__(
self,
@@ -223,7 +225,7 @@ def forward(
timestep_embed = self.time_proj(timesteps)
if self.config.use_timestep_embedding:
- timestep_embed = self.time_mlp(timestep_embed)
+ timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
else:
timestep_embed = timestep_embed[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
diff --git a/src/diffusers/models/unets/unet_1d_blocks.py b/src/diffusers/models/unets/unet_1d_blocks.py
index 8fc27e94c474..f08e6070845e 100644
--- a/src/diffusers/models/unets/unet_1d_blocks.py
+++ b/src/diffusers/models/unets/unet_1d_blocks.py
@@ -217,7 +217,7 @@ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tens
if self.upsample:
hidden_states = self.upsample(hidden_states)
if self.downsample:
- self.downsample = self.downsample(hidden_states)
+ hidden_states = self.downsample(hidden_states)
return hidden_states
diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py
index 5972505f2897..448ec051a032 100644
--- a/src/diffusers/models/unets/unet_2d.py
+++ b/src/diffusers/models/unets/unet_2d.py
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
- Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
+ Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -89,6 +89,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
conditioning with `class_embed_type` equal to `None`.
"""
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["norm"]
+
@register_to_config
def __init__(
self,
@@ -97,9 +100,11 @@ def __init__(
out_channels: int = 3,
center_input_sample: bool = False,
time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+ mid_block_type: Optional[str] = "UNetMidBlock2D",
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
layers_per_block: int = 2,
@@ -122,7 +127,7 @@ def __init__(
super().__init__()
self.sample_size = sample_size
- time_embed_dim = block_out_channels[0] * 4
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
# Check inputs
if len(down_block_types) != len(up_block_types):
@@ -191,19 +196,22 @@ def __init__(
self.down_blocks.append(down_block)
# mid
- self.mid_block = UNetMidBlock2D(
- in_channels=block_out_channels[-1],
- temb_channels=time_embed_dim,
- dropout=dropout,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_time_scale_shift=resnet_time_scale_shift,
- attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
- resnet_groups=norm_num_groups,
- attn_groups=attn_norm_num_groups,
- add_attention=add_attention,
- )
+ if mid_block_type is None:
+ self.mid_block = None
+ else:
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ dropout=dropout,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ attn_groups=attn_norm_num_groups,
+ add_attention=add_attention,
+ )
# up
reversed_block_out_channels = list(reversed(block_out_channels))
@@ -232,7 +240,6 @@ def __init__(
dropout=dropout,
)
self.up_blocks.append(up_block)
- prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
@@ -315,7 +322,8 @@ def forward(
down_block_res_samples += res_samples
# 4. mid
- sample = self.mid_block(sample, emb)
+ if self.mid_block is not None:
+ sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py
index 93a0a82cdcff..e082d524e766 100644
--- a/src/diffusers/models/unets/unet_2d_blocks.py
+++ b/src/diffusers/models/unets/unet_2d_blocks.py
@@ -18,7 +18,7 @@
import torch.nn.functional as F
from torch import nn
-from ...utils import deprecate, is_torch_version, logging
+from ...utils import deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..activations import get_activation
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -731,12 +731,19 @@ def __init__(
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
+ self.gradient_checkpointing = False
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if attn is not None:
- hidden_states = attn(hidden_states, temb=temb)
- hidden_states = resnet(hidden_states, temb)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
+ else:
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -859,18 +866,7 @@ def forward(
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -879,12 +875,7 @@ def custom_forward(*inputs):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = attn(
hidden_states,
@@ -1116,6 +1107,8 @@ def __init__(
else:
self.downsamplers = None
+ self.gradient_checkpointing = False
+
def forward(
self,
hidden_states: torch.Tensor,
@@ -1130,9 +1123,14 @@ def forward(
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
- hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(hidden_states, **cross_attention_kwargs)
- output_states = output_states + (hidden_states,)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
+ output_states = output_states + (hidden_states,)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
+ output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
@@ -1257,24 +1255,8 @@ def forward(
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1371,22 +1353,8 @@ def forward(
output_states = ()
for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -1859,22 +1827,8 @@ def forward(
output_states = ()
for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -2011,18 +1965,8 @@ def forward(
mask = attention_mask
for resnet, attn in zip(self.resnets, self.attentions):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -2106,22 +2050,8 @@ def forward(
output_states = ()
for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -2215,23 +2145,11 @@ def forward(
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ resnet,
hidden_states,
temb,
- **ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
@@ -2354,6 +2272,7 @@ def __init__(
else:
self.upsamplers = None
+ self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(
@@ -2375,8 +2294,12 @@ def forward(
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(hidden_states)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -2520,24 +2443,8 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -2653,22 +2560,8 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -3183,22 +3076,8 @@ def forward(
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -3341,18 +3220,8 @@ def forward(
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -3444,22 +3313,8 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -3572,23 +3427,11 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
for resnet, attn in zip(self.resnets, self.attentions):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ resnet,
hidden_states,
temb,
- **ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 15f126d686ee..2fd15f6f91e0 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -172,11 +172,12 @@ class conditioning with `class_embed_type` equal to `None`.
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
+ _skip_layerwise_casting_patterns = ["norm"]
@register_to_config
def __init__(
self,
- sample_size: Optional[int] = None,
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
@@ -928,10 +929,6 @@ def fn_recursive_set_attention_slice(
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -1015,10 +1012,11 @@ def get_time_embed(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py
index 8b472a89e13d..8d7614a23383 100644
--- a/src/diffusers/models/unets/unet_3d_blocks.py
+++ b/src/diffusers/models/unets/unet_3d_blocks.py
@@ -17,7 +17,7 @@
import torch
from torch import nn
-from ...utils import deprecate, is_torch_version, logging
+from ...utils import deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..attention import Attention
from ..resnet import (
@@ -1078,31 +1078,14 @@ def forward(
)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if self.training and self.gradient_checkpointing: # TODO
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- **ckpt_kwargs,
- )
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
else:
hidden_states = attn(
hidden_states,
@@ -1110,11 +1093,7 @@ def custom_forward(*inputs):
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
- hidden_states = resnet(
- hidden_states,
- temb,
- image_only_indicator=image_only_indicator,
- )
+ hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
return hidden_states
@@ -1168,35 +1147,10 @@ def forward(
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
output_states = ()
for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- use_reentrant=False,
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
else:
- hidden_states = resnet(
- hidden_states,
- temb,
- image_only_indicator=image_only_indicator,
- )
+ hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
output_states = output_states + (hidden_states,)
@@ -1281,25 +1235,8 @@ def forward(
blocks = list(zip(self.resnets, self.attentions))
for resnet, attn in blocks:
- if self.training and self.gradient_checkpointing: # TODO
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
hidden_states = attn(
hidden_states,
@@ -1308,11 +1245,7 @@ def custom_forward(*inputs):
return_dict=False,
)[0]
else:
- hidden_states = resnet(
- hidden_states,
- temb,
- image_only_indicator=image_only_indicator,
- )
+ hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1375,6 +1308,7 @@ def forward(
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
+ upsample_size: Optional[int] = None,
) -> torch.Tensor:
for resnet in self.resnets:
# pop res hidden states
@@ -1383,39 +1317,14 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- use_reentrant=False,
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
else:
- hidden_states = resnet(
- hidden_states,
- temb,
- image_only_indicator=image_only_indicator,
- )
+ hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states)
+ hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
@@ -1485,6 +1394,7 @@ def forward(
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
+ upsample_size: Optional[int] = None,
) -> torch.Tensor:
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
@@ -1493,25 +1403,8 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing: # TODO
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1519,11 +1412,7 @@ def custom_forward(*inputs):
return_dict=False,
)[0]
else:
- hidden_states = resnet(
- hidden_states,
- temb,
- image_only_indicator=image_only_indicator,
- )
+ hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1533,6 +1422,6 @@ def custom_forward(*inputs):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states)
+ hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py
index 3081fdc4700c..a148cf6cbe06 100644
--- a/src/diffusers/models/unets/unet_3d_condition.py
+++ b/src/diffusers/models/unets/unet_3d_condition.py
@@ -37,11 +37,7 @@
from ..modeling_utils import ModelMixin
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
- CrossAttnDownBlock3D,
- CrossAttnUpBlock3D,
- DownBlock3D,
UNetMidBlock3DCrossAttn,
- UpBlock3D,
get_down_block,
get_up_block,
)
@@ -97,6 +93,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
"""
_supports_gradient_checkpointing = False
+ _skip_layerwise_casting_patterns = ["norm", "time_embedding"]
@register_to_config
def __init__(
@@ -471,10 +468,6 @@ def set_default_attn_processor(self):
self.set_attn_processor(processor)
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
- if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
- module.gradient_checkpointing = value
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -624,10 +617,11 @@ def forward(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
@@ -644,8 +638,10 @@ def forward(
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
- emb = emb.repeat_interleave(repeats=num_frames, dim=0)
- encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+ emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+ num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+ )
# 2. pre-process
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 6ab3a577b892..c275e16744f4 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -35,11 +35,7 @@
from ..modeling_utils import ModelMixin
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
- CrossAttnDownBlock3D,
- CrossAttnUpBlock3D,
- DownBlock3D,
UNetMidBlock3DCrossAttn,
- UpBlock3D,
get_down_block,
get_up_block,
)
@@ -436,11 +432,6 @@ def set_default_attn_processor(self):
self.set_attn_processor(processor)
- # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
- if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
- module.gradient_checkpointing = value
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -575,10 +566,11 @@ def forward(
# TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
if isinstance(timesteps, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
@@ -600,7 +592,7 @@ def forward(
# 3. time + FPS embeddings.
emb = t_emb + fps_emb
- emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+ emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
# 4. context embeddings.
# The context embeddings consist of both text embeddings from the input prompt
@@ -628,7 +620,7 @@ def forward(
image_emb = self.context_embedding(image_embeddings)
image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim)
context_emb = torch.cat([context_emb, image_emb], dim=1)
- context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
+ context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames)
image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
image_latents.shape[0] * image_latents.shape[2],
diff --git a/src/diffusers/models/unets/unet_kandinsky3.py b/src/diffusers/models/unets/unet_kandinsky3.py
index f611e7d82b1d..73bf0020b481 100644
--- a/src/diffusers/models/unets/unet_kandinsky3.py
+++ b/src/diffusers/models/unets/unet_kandinsky3.py
@@ -205,10 +205,6 @@ def set_default_attn_processor(self):
"""
self.set_attn_processor(AttnProcessor())
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 6125feba5899..bd83024c9b7c 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -22,7 +22,7 @@
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
-from ...utils import BaseOutput, deprecate, is_torch_version, logging
+from ...utils import BaseOutput, deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..attention import BasicTransformerBlock
from ..attention_processor import (
@@ -323,26 +323,8 @@ def forward(
blocks = zip(self.resnets, self.motion_modules)
for resnet, motion_module in blocks:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- use_reentrant=False,
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
-
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
@@ -513,24 +495,8 @@ def forward(
blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
for i, (resnet, attn, motion_module) in enumerate(blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
@@ -543,10 +509,7 @@ def custom_forward(*inputs):
return_dict=False,
)[0]
- hidden_states = motion_module(
- hidden_states,
- num_frames=num_frames,
- )
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
@@ -732,24 +695,8 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
@@ -762,10 +709,7 @@ def custom_forward(*inputs):
return_dict=False,
)[0]
- hidden_states = motion_module(
- hidden_states,
- num_frames=num_frames,
- )
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -895,25 +839,8 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- use_reentrant=False,
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
@@ -1079,35 +1006,13 @@ def forward(
return_dict=False,
)[0]
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(motion_module),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ motion_module, hidden_states, None, None, None, num_frames, None
)
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
- hidden_states = motion_module(
- hidden_states,
- num_frames=num_frames,
- )
+ hidden_states = motion_module(hidden_states, None, None, None, num_frames, None)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
return hidden_states
@@ -1301,6 +1206,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
"""
_supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["norm"]
@register_to_config
def __init__(
@@ -1965,10 +1871,6 @@ def set_default_attn_processor(self) -> None:
self.set_attn_processor(processor)
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
- if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
- module.gradient_checkpointing = value
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -2114,10 +2016,11 @@ def forward(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
@@ -2156,7 +2059,7 @@ def forward(
aug_emb = self.add_embedding(add_embeds)
emb = emb if aug_emb is None else emb + aug_emb
- emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+ emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
@@ -2165,7 +2068,10 @@ def forward(
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds)
- image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
+ image_embeds = [
+ image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames)
+ for image_embed in image_embeds
+ ]
encoder_hidden_states = (encoder_hidden_states, image_embeds)
# 2. pre-process
diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
index 9fb975bc32d9..059a6e807c8e 100644
--- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py
+++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -320,10 +320,6 @@ def set_default_attn_processor(self):
self.set_attn_processor(processor)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
@@ -382,16 +378,31 @@ def forward(
If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
returned, otherwise a `tuple` is returned where the first element is the sample tensor.
"""
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
@@ -420,9 +431,11 @@ def forward(
sample = sample.flatten(0, 1)
# Repeat the embeddings num_video_frames times
# emb: [batch, channels] -> [batch * frames, channels]
- emb = emb.repeat_interleave(num_frames, dim=0)
+ emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
- encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+ num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+ )
# 2. pre-process
sample = self.conv_in(sample)
@@ -457,15 +470,23 @@ def forward(
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
image_only_indicator=image_only_indicator,
)
else:
@@ -473,6 +494,7 @@ def forward(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
image_only_indicator=image_only_indicator,
)
diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py
index 7deea9a714d4..f57754435fdc 100644
--- a/src/diffusers/models/unets/unet_stable_cascade.py
+++ b/src/diffusers/models/unets/unet_stable_cascade.py
@@ -387,9 +387,6 @@ def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=Tru
self.gradient_checkpointing = False
- def _set_gradient_checkpointing(self, value=False):
- self.gradient_checkpointing = value
-
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
torch.nn.init.xavier_uniform_(m.weight)
@@ -455,30 +452,19 @@ def _down_encode(self, x, r_embed, clip):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
for down_block, downscaler, repmap in block_group:
x = downscaler(x)
for i in range(len(repmap) + 1):
for block in down_block:
if isinstance(block, SDCascadeResBlock):
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
+ x = self._gradient_checkpointing_func(block, x)
elif isinstance(block, SDCascadeAttnBlock):
- x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block), x, clip, use_reentrant=False
- )
+ x = self._gradient_checkpointing_func(block, x, clip)
elif isinstance(block, SDCascadeTimestepBlock):
- x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block), x, r_embed, use_reentrant=False
- )
+ x = self._gradient_checkpointing_func(block, x, r_embed)
else:
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False)
+ x = self._gradient_checkpointing_func(block)
if i < len(repmap):
x = repmap[i](x)
level_outputs.insert(0, x)
@@ -504,14 +490,7 @@ def _up_decode(self, level_outputs, r_embed, clip):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
for i, (up_block, upscaler, repmap) in enumerate(block_group):
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
@@ -523,19 +502,13 @@ def custom_forward(*inputs):
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
)
x = x.to(orig_type)
- x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block), x, skip, use_reentrant=False
- )
+ x = self._gradient_checkpointing_func(block, x, skip)
elif isinstance(block, SDCascadeAttnBlock):
- x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block), x, clip, use_reentrant=False
- )
+ x = self._gradient_checkpointing_func(block, x, clip)
elif isinstance(block, SDCascadeTimestepBlock):
- x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block), x, r_embed, use_reentrant=False
- )
+ x = self._gradient_checkpointing_func(block, x, r_embed)
else:
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
+ x = self._gradient_checkpointing_func(block, x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
diff --git a/src/diffusers/models/unets/uvit_2d.py b/src/diffusers/models/unets/uvit_2d.py
index 8a379bf5f9c3..94b39c84f055 100644
--- a/src/diffusers/models/unets/uvit_2d.py
+++ b/src/diffusers/models/unets/uvit_2d.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -148,9 +148,6 @@ def __init__(
self.gradient_checkpointing = False
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
- pass
-
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
@@ -181,7 +178,7 @@ def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds
hidden_states = self.project_to_hidden(hidden_states)
for layer in self.transformer_layers:
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
def layer_(*args):
return checkpoint(layer, *args)
diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py
index cf07e45b0c5c..af04ae4b93cf 100644
--- a/src/diffusers/models/upsampling.py
+++ b/src/diffusers/models/upsampling.py
@@ -165,6 +165,14 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if self.interpolate:
+ # upsample_nearest_nhwc also fails when the number of output elements is large
+ # https://github.com/pytorch/pytorch/issues/141831
+ scale_factor = (
+ 2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])])
+ )
+ if hidden_states.numel() * scale_factor > pow(2, 31):
+ hidden_states = hidden_states.contiguous()
+
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
else:
diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py
index f20bd94edffa..e0b3576e4426 100644
--- a/src/diffusers/optimization.py
+++ b/src/diffusers/optimization.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -258,7 +258,7 @@ def get_polynomial_decay_schedule_with_warmup(
lr_init = optimizer.defaults["lr"]
if not (lr_init > lr_end):
- raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
+ raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
def lr_lambda(current_step: int):
if current_step < num_warmup_steps:
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 7366520f4692..b901d42d9cf7 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -116,6 +116,7 @@
"VersatileDiffusionTextToImagePipeline",
]
)
+ _import_structure["allegro"] = ["AllegroPipeline"]
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = [
"AnimateDiffPipeline",
@@ -126,12 +127,18 @@
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["flux"] = [
+ "FluxControlPipeline",
+ "FluxControlInpaintPipeline",
+ "FluxControlImg2ImgPipeline",
"FluxControlNetPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxPipeline",
+ "FluxFillPipeline",
+ "FluxPriorReduxPipeline",
+ "ReduxImageEncoder",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -147,6 +154,8 @@
"CogVideoXFunControlPipeline",
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
+ _import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
+ _import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -156,6 +165,9 @@
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
"StableDiffusionXLControlNetPipeline",
+ "StableDiffusionXLControlNetUnionPipeline",
+ "StableDiffusionXLControlNetUnionInpaintPipeline",
+ "StableDiffusionXLControlNetUnionImg2ImgPipeline",
]
)
_import_structure["pag"].extend(
@@ -165,8 +177,10 @@
"KolorsPAGPipeline",
"HunyuanDiTPAGPipeline",
"StableDiffusion3PAGPipeline",
+ "StableDiffusion3PAGImg2ImgPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionPAGImg2ImgPipeline",
+ "StableDiffusionPAGInpaintPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPAGInpaintPipeline",
@@ -174,6 +188,7 @@
"StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
"PixArtSigmaPAGPipeline",
+ "SanaPAGPipeline",
]
)
_import_structure["controlnet_xs"].extend(
@@ -201,7 +216,17 @@
"IFPipeline",
"IFSuperResolutionPipeline",
]
+ _import_structure["easyanimate"] = [
+ "EasyAnimatePipeline",
+ "EasyAnimateInpaintPipeline",
+ "EasyAnimateControlPipeline",
+ ]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
+ _import_structure["hunyuan_video"] = [
+ "HunyuanVideoPipeline",
+ "HunyuanSkyreelsImageToVideoPipeline",
+ "HunyuanVideoImageToVideoPipeline",
+ ]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
@@ -239,17 +264,23 @@
]
)
_import_structure["latte"] = ["LattePipeline"]
- _import_structure["lumina"] = ["LuminaText2ImgPipeline"]
+ _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
+ _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
+ _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
+ "MarigoldIntrinsicsPipeline",
"MarigoldNormalsPipeline",
]
)
+ _import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
+ _import_structure["omnigen"] = ["OmniGenPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
+ _import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_audio"] = [
@@ -325,6 +356,7 @@
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
+ _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -454,6 +486,7 @@
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import *
else:
+ from .allegro import AllegroPipeline
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
from .animatediff import (
AnimateDiffControlNetPipeline,
@@ -478,6 +511,8 @@
CogVideoXVideoToVideoPipeline,
)
from .cogview3 import CogView3PlusPipeline
+ from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
+ from .consisid import ConsisIDPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
@@ -486,6 +521,9 @@
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
+ StableDiffusionXLControlNetUnionImg2ImgPipeline,
+ StableDiffusionXLControlNetUnionInpaintPipeline,
+ StableDiffusionXLControlNetUnionPipeline,
)
from .controlnet_hunyuandit import (
HunyuanDiTControlNetPipeline,
@@ -517,13 +555,29 @@
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
+ from .easyanimate import (
+ EasyAnimateControlPipeline,
+ EasyAnimateInpaintPipeline,
+ EasyAnimatePipeline,
+ )
from .flux import (
+ FluxControlImg2ImgPipeline,
+ FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
+ FluxControlPipeline,
+ FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
+ FluxPriorReduxPipeline,
+ ReduxImageEncoder,
+ )
+ from .hunyuan_video import (
+ HunyuanSkyreelsImageToVideoPipeline,
+ HunyuanVideoImageToVideoPipeline,
+ HunyuanVideoPipeline,
)
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
@@ -564,21 +618,29 @@
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
- from .lumina import LuminaText2ImgPipeline
+ from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
+ from .lumina import LuminaPipeline, LuminaText2ImgPipeline
+ from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (
MarigoldDepthPipeline,
+ MarigoldIntrinsicsPipeline,
MarigoldNormalsPipeline,
)
+ from .mochi import MochiPipeline
from .musicldm import MusicLDMPipeline
+ from .omnigen import OmniGenPipeline
from .pag import (
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,
KolorsPAGPipeline,
PixArtSigmaPAGPipeline,
+ SanaPAGPipeline,
+ StableDiffusion3PAGImg2ImgPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGImg2ImgPipeline,
+ StableDiffusionPAGInpaintPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
StableDiffusionXLControlNetPAGPipeline,
@@ -589,6 +651,7 @@
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
+ from .sana import SanaPipeline, SanaSprintPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
@@ -646,6 +709,7 @@
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
+ from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
diff --git a/src/diffusers/pipelines/allegro/__init__.py b/src/diffusers/pipelines/allegro/__init__.py
new file mode 100644
index 000000000000..2162b825e0a2
--- /dev/null
+++ b/src/diffusers/pipelines/allegro/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_allegro"] = ["AllegroPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_allegro import AllegroPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py
new file mode 100644
index 000000000000..cb36a7a672de
--- /dev/null
+++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py
@@ -0,0 +1,958 @@
+# Copyright 2024 The RhymesAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import math
+import re
+import urllib.parse as ul
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro
+from ...models.embeddings import get_3d_rotary_pos_embed_allegro
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ BACKENDS_MAPPING,
+ deprecate,
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from .pipeline_output import AllegroPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__)
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import AutoencoderKLAllegro, AllegroPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda")
+ >>> pipe.enable_vae_tiling()
+
+ >>> prompt = (
+ ... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, "
+ ... "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this "
+ ... "location might be a popular spot for docking fishing boats."
+ ... )
+ >>> video = pipe(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=15)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class AllegroPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using Allegro.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AllegroAutoEncoderKL3D`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`AllegroTransformer3DModel`]):
+ A text conditioned `AllegroTransformer3DModel` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ bad_punct_regex = re.compile(
+ r"["
+ + "#®•©™&@·º½¾¿¡§~"
+ + r"\)"
+ + r"\("
+ + r"\]"
+ + r"\["
+ + r"\}"
+ + r"\{"
+ + r"\|"
+ + "\\"
+ + r"\/"
+ + r"\*"
+ + r"]{1,}"
+ ) # noqa
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLAllegro,
+ transformer: AllegroTransformer3DModel,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+ self.vae_scale_factor_spatial = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ )
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->512, num_images_per_prompt->num_videos_per_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 512,
+ **kwargs,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
+ string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 512): Maximum sequence length to use for the prompt.
+ """
+
+ if "mask_feature" in kwargs:
+ deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
+
+ if device is None:
+ device = self._execution_device
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because T5 can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ negative_prompt_attention_mask = uncond_input.attention_mask
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_videos_per_prompt)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ num_frames,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if num_frames <= 0:
+ raise ValueError(f"`num_frames` have to be positive but is {num_frames}.")
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if num_frames % 2 == 0:
+ num_frames = math.ceil(num_frames / self.vae_scale_factor_temporal)
+ else:
+ num_frames = math.ceil((num_frames - 1) / self.vae_scale_factor_temporal) + 1
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+ latents = 1 / self.vae.config.scaling_factor * latents
+ frames = self.vae.decode(latents).sample
+ frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, channels, num_frames, height, width]
+ return frames
+
+ def _prepare_rotary_positional_embeddings(
+ self,
+ batch_size: int,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ ):
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+ start, stop = (0, 0), (grid_height, grid_width)
+ freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=(start, stop),
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ interpolation_scale=(
+ self.transformer.config.interpolation_scale_t,
+ self.transformer.config.interpolation_scale_h,
+ self.transformer.config.interpolation_scale_w,
+ ),
+ device=device,
+ )
+
+ grid_t = grid_t.to(dtype=torch.long)
+ grid_h = grid_h.to(dtype=torch.long)
+ grid_w = grid_w.to(dtype=torch.long)
+
+ pos = torch.cartesian_prod(grid_t, grid_h, grid_w)
+ pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous()
+ grid_t, grid_h, grid_w = pos
+
+ return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 100,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.5,
+ num_frames: Optional[int] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ clean_caption: bool = True,
+ max_sequence_length: int = 512,
+ ) -> Union[AllegroPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
+ usually at the expense of lower video quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ num_frames: (`int`, *optional*, defaults to 88):
+ The number controls the generated video frames.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated video.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate video. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ max_sequence_length (`int` defaults to `512`):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated videos.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ num_frames = num_frames or self.transformer.config.sample_frames * self.vae_scale_factor_temporal
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+
+ self.check_inputs(
+ prompt,
+ num_frames,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+ self._guidance_scale = guidance_scale
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+ if prompt_embeds.ndim == 3:
+ prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare rotary embeddings
+ image_rotary_emb = self._prepare_rotary_positional_embeddings(
+ batch_size, height, width, latents.size(2), device
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ video = self.decode_latents(latents)
+ video = video[:, :, :num_frames, :height, :width]
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return AllegroPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/allegro/pipeline_output.py b/src/diffusers/pipelines/allegro/pipeline_output.py
new file mode 100644
index 000000000000..6a721783ca86
--- /dev/null
+++ b/src/diffusers/pipelines/allegro/pipeline_output.py
@@ -0,0 +1,23 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class AllegroPipelineOutput(BaseOutput):
+ r"""
+ Output class for Allegro pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
diff --git a/src/diffusers/pipelines/amused/pipeline_amused.py b/src/diffusers/pipelines/amused/pipeline_amused.py
index a8c24b0aeecc..12f7dc7c59d4 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused.py
@@ -20,10 +20,18 @@
from ...image_processor import VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
-from ...utils import replace_example_docstring
+from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -66,7 +74,9 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
+ self.vae_scale_factor = (
+ 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
+ )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@@ -297,6 +307,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type == "latent":
output = latents
else:
diff --git a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
index c74275b414d4..7ac05b39c3a8 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
@@ -20,10 +20,18 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
-from ...utils import replace_example_docstring
+from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -81,7 +89,9 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
+ self.vae_scale_factor = (
+ 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
+ )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@@ -323,6 +333,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type == "latent":
output = latents
else:
diff --git a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
index 24801e0ef977..d908c32745c2 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
@@ -21,10 +21,18 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
-from ...utils import replace_example_docstring
+from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -89,7 +97,9 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
+ self.vae_scale_factor = (
+ 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
+ )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
@@ -354,6 +364,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type == "latent":
output = latents
else:
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index cb6f50f43c4f..d3ad5cc13ce3 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -19,7 +19,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
@@ -34,6 +34,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -47,8 +48,16 @@
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -74,6 +83,7 @@ class AnimateDiffPipeline(
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
+ FromSingleFileMixin,
):
r"""
Pipeline for text-to-video generation.
@@ -139,7 +149,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
@@ -844,6 +854,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
index 5357d6d5b8d9..db546398643b 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
@@ -20,23 +20,37 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import (
+ AutoencoderKL,
+ ControlNetModel,
+ ImageProjection,
+ MultiControlNetModel,
+ UNet2DConditionModel,
+ UNetMotionModel,
+)
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
-from ..controlnet.multicontrolnet import MultiControlNetModel
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -111,6 +125,7 @@ class AnimateDiffControlNetPipeline(
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
+ FromSingleFileMixin,
):
r"""
Pipeline for text-to-video generation with ControlNet guidance.
@@ -174,7 +189,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_video_processor = VideoProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -1084,6 +1099,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
index 6016917537b9..958eb5fb5134 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
@@ -48,6 +48,7 @@
)
from ...utils import (
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -60,8 +61,16 @@
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -307,10 +316,14 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt
def encode_prompt(
@@ -438,7 +451,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -497,8 +512,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1261,6 +1278,9 @@ def __call__(
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
index 8b037cdc34fb..8c51fddcd5fc 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
@@ -22,14 +22,15 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
-from ...models.controlnet_sparsectrl import SparseControlNetModel
+from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -42,8 +43,16 @@
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```python
@@ -127,6 +136,7 @@ class AnimateDiffSparseControlNetPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
+ FromSingleFileMixin,
):
r"""
Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls
@@ -188,7 +198,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -994,6 +1004,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 11. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
index 20e88075ed05..116397055272 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -19,7 +19,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
@@ -31,7 +31,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
@@ -40,8 +40,16 @@
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -178,6 +186,7 @@ class AnimateDiffVideoToVideoPipeline(
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
+ FromSingleFileMixin,
):
r"""
Pipeline for video-to-video generation.
@@ -216,7 +225,7 @@ def __init__(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
- unet: UNet2DConditionModel,
+ unet: Union[UNet2DConditionModel, UNetMotionModel],
motion_adapter: MotionAdapter,
scheduler: Union[
DDIMScheduler,
@@ -243,7 +252,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
def encode_prompt(
@@ -662,12 +671,6 @@ def prepare_latents(
self.vae.to(dtype=torch.float32)
if isinstance(generator, list):
- if len(generator) != batch_size:
- raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
- )
-
init_latents = [
self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
for i in range(batch_size)
@@ -1043,6 +1046,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 10. Post-processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
index 9a93f1d28d35..ce974094936a 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
@@ -20,8 +20,15 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import (
+ AutoencoderKL,
+ ControlNetModel,
+ ImageProjection,
+ MultiControlNetModel,
+ UNet2DConditionModel,
+ UNetMotionModel,
+)
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import (
@@ -32,18 +39,25 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
-from ..controlnet.multicontrolnet import MultiControlNetModel
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -190,6 +204,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
+ FromSingleFileMixin,
):
r"""
Pipeline for video-to-video generation with ControlNet guidance.
@@ -232,7 +247,7 @@ def __init__(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
- unet: UNet2DConditionModel,
+ unet: Union[UNet2DConditionModel, UNetMotionModel],
motion_adapter: MotionAdapter,
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: Union[
@@ -264,7 +279,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_video_processor = VideoProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -788,12 +803,6 @@ def prepare_latents(
self.vae.to(dtype=torch.float32)
if isinstance(generator, list):
- if len(generator) != batch_size:
- raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
- )
-
init_latents = [
self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
for i in range(batch_size)
@@ -1325,6 +1334,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 11. Post-processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
index 105ca40f773f..14c6d44fc586 100644
--- a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
+++ b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
@@ -22,13 +22,21 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -94,7 +102,7 @@ def __init__(
scheduler=scheduler,
vocoder=vocoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def _encode_prompt(
self,
@@ -530,6 +538,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post-processing
mel_spectrogram = self.decode_latents(latents)
diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
index 2af3078f7412..00bed864ba34 100644
--- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -38,7 +38,7 @@
from ...models.transformers.transformer_2d import Transformer2DModel
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
-from ...utils import BaseOutput, is_torch_version, logging
+from ...utils import BaseOutput, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -673,11 +673,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
sample: torch.Tensor,
@@ -768,10 +763,11 @@ def forward(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
@@ -1112,24 +1108,8 @@ def forward(
)
for i in range(num_layers):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.resnets[i]),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
forward_encoder_hidden_states = encoder_hidden_states
@@ -1140,8 +1120,8 @@ def custom_forward(*inputs):
else:
forward_encoder_hidden_states = None
forward_encoder_attention_mask = None
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+ hidden_states = self._gradient_checkpointing_func(
+ self.attentions[i * num_attention_per_layer + idx],
hidden_states,
forward_encoder_hidden_states,
None, # timestep
@@ -1149,7 +1129,6 @@ def custom_forward(*inputs):
cross_attention_kwargs,
attention_mask,
forward_encoder_attention_mask,
- **ckpt_kwargs,
)[0]
else:
hidden_states = self.resnets[i](hidden_states, temb)
@@ -1290,18 +1269,7 @@ def forward(
)
for i in range(len(self.resnets[1:])):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
forward_encoder_hidden_states = encoder_hidden_states
@@ -1312,8 +1280,8 @@ def custom_forward(*inputs):
else:
forward_encoder_hidden_states = None
forward_encoder_attention_mask = None
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+ hidden_states = self._gradient_checkpointing_func(
+ self.attentions[i * num_attention_per_layer + idx],
hidden_states,
forward_encoder_hidden_states,
None, # timestep
@@ -1321,14 +1289,8 @@ def custom_forward(*inputs):
cross_attention_kwargs,
attention_mask,
forward_encoder_attention_mask,
- **ckpt_kwargs,
)[0]
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.resnets[i + 1]),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb)
else:
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
@@ -1464,24 +1426,8 @@ def forward(
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.resnets[i]),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
if cross_attention_dim is not None and idx <= 1:
forward_encoder_hidden_states = encoder_hidden_states
@@ -1492,8 +1438,8 @@ def custom_forward(*inputs):
else:
forward_encoder_hidden_states = None
forward_encoder_attention_mask = None
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+ hidden_states = self._gradient_checkpointing_func(
+ self.attentions[i * num_attention_per_layer + idx],
hidden_states,
forward_encoder_hidden_states,
None, # timestep
@@ -1501,7 +1447,6 @@ def custom_forward(*inputs):
cross_attention_kwargs,
attention_mask,
forward_encoder_attention_mask,
- **ckpt_kwargs,
)[0]
else:
hidden_states = self.resnets[i](hidden_states, temb)
diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index b45771d7de74..b8b5d07af529 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -48,8 +48,20 @@
if is_librosa_available():
import librosa
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -207,7 +219,7 @@ def __init__(
scheduler=scheduler,
vocoder=vocoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
def enable_vae_slicing(self):
@@ -225,7 +237,7 @@ def disable_vae_slicing(self):
"""
self.vae.disable_slicing()
- def enable_model_cpu_offload(self, gpu_id=0):
+ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -237,11 +249,23 @@ def enable_model_cpu_offload(self, gpu_id=0):
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
- device = torch.device(f"cuda:{gpu_id}")
+ torch_device = torch.device(device)
+ device_index = torch_device.index
+
+ if gpu_id is not None and device_index is not None:
+ raise ValueError(
+ f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
+ f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
+ )
+
+ device_type = torch_device.type
+ device = torch.device(f"{device_type}:{gpu_id or torch_device.index}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+ device_mod = getattr(torch, device.type, None)
+ if hasattr(device_mod, "empty_cache") and device_mod.is_available():
+ device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
model_sequence = [
self.text_encoder.text_model,
@@ -1033,6 +1057,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
self.maybe_free_model_hooks()
# 8. Post-processing
diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index 58eaf6b46d0a..ea60e66d2db9 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -12,20 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5Tokenizer, UMT5EncoderModel
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -124,6 +132,10 @@ class AuraFlowPipeline(DiffusionPipeline):
_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ ]
def __init__(
self,
@@ -139,9 +151,7 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def check_inputs(
@@ -154,10 +164,19 @@ def check_inputs(
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
+ callback_on_step_end_tensor_inputs=None,
):
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -380,6 +399,14 @@ def upcast_vae(self):
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -387,7 +414,6 @@ def __call__(
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
@@ -402,6 +428,10 @@ def __call__(
max_sequence_length: int = 256,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Function invoked when calling the pipeline for generation.
@@ -424,10 +454,6 @@ def __call__(
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -460,6 +486,15 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Examples:
@@ -481,8 +516,11 @@ def __call__(
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
+ self._guidance_scale = guidance_scale
+
# 2. Determine batch size.
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -522,9 +560,7 @@ def __call__(
# 4. Prepare timesteps
# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
- )
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
@@ -541,6 +577,7 @@ def __call__(
# 6. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
@@ -567,10 +604,22 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type == "latent":
image = latents
else:
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index 0214d7dd6f3c..6a5f6098b6fb 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,9 +18,11 @@
from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
+from ..models.controlnets import ControlNetUnionModel
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
from .cogview3 import CogView3PlusPipeline
+from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
@@ -28,12 +30,22 @@
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
+ StableDiffusionXLControlNetUnionImg2ImgPipeline,
+ StableDiffusionXLControlNetUnionInpaintPipeline,
+ StableDiffusionXLControlNetUnionPipeline,
+)
+from .controlnet_sd3 import (
+ StableDiffusion3ControlNetInpaintingPipeline,
+ StableDiffusion3ControlNetPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import (
+ FluxControlImg2ImgPipeline,
+ FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
FluxControlNetPipeline,
+ FluxControlPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
@@ -57,14 +69,18 @@
)
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
-from .lumina import LuminaText2ImgPipeline
+from .lumina import LuminaPipeline
+from .lumina2 import Lumina2Pipeline
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
+ SanaPAGPipeline,
+ StableDiffusion3PAGImg2ImgPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGImg2ImgPipeline,
+ StableDiffusionPAGInpaintPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGImg2ImgPipeline,
StableDiffusionXLControlNetPAGPipeline,
@@ -73,6 +89,7 @@
StableDiffusionXLPAGPipeline,
)
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
+from .sana import SanaPipeline
from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
from .stable_diffusion import (
StableDiffusionImg2ImgPipeline,
@@ -106,11 +123,15 @@
("kandinsky3", Kandinsky3Pipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
+ ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
+ ("stable-diffusion-3-controlnet", StableDiffusion3ControlNetPipeline),
("wuerstchen", WuerstchenCombinedPipeline),
("cascade", StableCascadeCombinedPipeline),
("lcm", LatentConsistencyModelPipeline),
("pixart-alpha", PixArtAlphaPipeline),
("pixart-sigma", PixArtSigmaPipeline),
+ ("sana", SanaPipeline),
+ ("sana-pag", SanaPAGPipeline),
("stable-diffusion-pag", StableDiffusionPAGPipeline),
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
@@ -118,9 +139,13 @@
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline),
("flux", FluxPipeline),
+ ("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline),
- ("lumina", LuminaText2ImgPipeline),
+ ("lumina", LuminaPipeline),
+ ("lumina2", Lumina2Pipeline),
("cogview3", CogView3PlusPipeline),
+ ("cogview4", CogView4Pipeline),
+ ("cogview4-control", CogView4ControlPipeline),
]
)
@@ -129,6 +154,7 @@
("stable-diffusion", StableDiffusionImg2ImgPipeline),
("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline),
("stable-diffusion-3", StableDiffusion3Img2ImgPipeline),
+ ("stable-diffusion-3-pag", StableDiffusion3PAGImg2ImgPipeline),
("if", IFImg2ImgPipeline),
("kandinsky", KandinskyImg2ImgCombinedPipeline),
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
@@ -136,11 +162,13 @@
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
+ ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
+ ("flux-control", FluxControlImg2ImgPipeline),
]
)
@@ -155,9 +183,13 @@
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
+ ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline),
+ ("stable-diffusion-3-controlnet", StableDiffusion3ControlNetInpaintingPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
("flux-controlnet", FluxControlNetInpaintPipeline),
+ ("flux-control", FluxControlInpaintPipeline),
+ ("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
]
)
@@ -276,7 +308,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
If you get the error message below, you need to finetune the weights for your downstream task:
```
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
@@ -368,7 +400,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
```py
>>> from diffusers import AutoPipelineForText2Image
- >>> pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> pipeline = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0]
```
"""
@@ -390,13 +422,20 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
+ if "ControlPipeline" in orig_class_name:
+ to_replace = "ControlPipeline"
+ else:
+ to_replace = "Pipeline"
if "controlnet" in kwargs:
- orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
+ else:
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
- orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
+ orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
@@ -424,7 +463,7 @@ def from_pipe(cls, pipeline, **kwargs):
>>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
>>> pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", requires_safety_checker=False
... )
>>> pipe_t2i = AutoPipelineForText2Image.from_pipe(pipe_i2i)
@@ -504,7 +543,9 @@ def from_pipe(cls, pipeline, **kwargs):
if k not in text_2_image_kwargs
}
- missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(text_2_image_kwargs.keys())
+ missing_modules = (
+ set(expected_modules) - set(text_2_image_cls._optional_components) - set(text_2_image_kwargs.keys())
+ )
if len(missing_modules) > 0:
raise ValueError(
@@ -563,7 +604,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
If you get the error message below, you need to finetune the weights for your downstream task:
```
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
@@ -655,7 +696,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
```py
>>> from diffusers import AutoPipelineForImage2Image
- >>> pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> pipeline = AutoPipelineForImage2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> image = pipeline(prompt, image).images[0]
```
"""
@@ -680,16 +721,28 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
# the `orig_class_name` can be:
# `- *Pipeline` (for regular text-to-image checkpoint)
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
# `- *Img2ImgPipeline` (for refiner checkpoint)
- to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
+ if "Img2Img" in orig_class_name:
+ to_replace = "Img2ImgPipeline"
+ elif "ControlPipeline" in orig_class_name:
+ to_replace = "ControlPipeline"
+ else:
+ to_replace = "Pipeline"
if "controlnet" in kwargs:
- orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+ else:
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+ if to_replace == "ControlPipeline":
+ orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
+
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
@@ -718,7 +771,7 @@ def from_pipe(cls, pipeline, **kwargs):
>>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
>>> pipe_t2i = AutoPipelineForText2Image.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", requires_safety_checker=False
... )
>>> pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i)
@@ -802,7 +855,9 @@ def from_pipe(cls, pipeline, **kwargs):
if k not in image_2_image_kwargs
}
- missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(image_2_image_kwargs.keys())
+ missing_modules = (
+ set(expected_modules) - set(image_2_image_cls._optional_components) - set(image_2_image_kwargs.keys())
+ )
if len(missing_modules) > 0:
raise ValueError(
@@ -860,7 +915,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
If you get the error message below, you need to finetune the weights for your downstream task:
```
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
@@ -952,7 +1007,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
```py
>>> from diffusers import AutoPipelineForInpainting
- >>> pipeline = AutoPipelineForInpainting.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> pipeline = AutoPipelineForInpainting.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
```
"""
@@ -977,15 +1032,26 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
# The `orig_class_name`` can be:
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
# - or *Pipeline (for regular text-to-image checkpoint)
- to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
+ if "Inpaint" in orig_class_name:
+ to_replace = "InpaintPipeline"
+ elif "ControlPipeline" in orig_class_name:
+ to_replace = "ControlPipeline"
+ else:
+ to_replace = "Pipeline"
if "controlnet" in kwargs:
- orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+ else:
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+ if to_replace == "ControlPipeline":
+ orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
@@ -1094,7 +1160,9 @@ def from_pipe(cls, pipeline, **kwargs):
if k not in inpainting_kwargs
}
- missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(inpainting_kwargs.keys())
+ missing_modules = (
+ set(expected_modules) - set(inpainting_cls._optional_components) - set(inpainting_kwargs.keys())
+ )
if len(missing_modules) > 0:
raise ValueError(
diff --git a/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py b/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py
index d92a07669059..e45f431d0b9d 100644
--- a/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py
+++ b/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
index 1be4761a9987..d2408417f590 100644
--- a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
+++ b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
@@ -167,26 +167,23 @@ def forward(
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
- if getattr(self.config, "gradient_checkpointing", False) and self.training:
+ if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled():
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs, past_key_value, output_attentions, query_length)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(layer_module),
+ layer_outputs = self._gradient_checkpointing_func(
+ layer_module,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ query_length,
)
else:
layer_outputs = layer_module(
diff --git a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
index ff23247b5f81..cbd8bef67945 100644
--- a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
+++ b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
@@ -20,6 +20,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -30,8 +31,16 @@
from .modeling_ctx_clip import ContextCLIPTextModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -336,6 +345,9 @@ def __call__(
latents,
)["prev_sample"]
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index 9cb042c9e80c..99ae9025cd3e 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -26,12 +26,19 @@
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -183,14 +190,12 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
- self.vae_scaling_factor_image = (
- self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -442,21 +447,39 @@ def _prepare_rotary_positional_embeddings(
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- grid_crops_coords = get_resize_crop_region_for_grid(
- (grid_height, grid_width), base_size_width, base_size_height
- )
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
- embed_dim=self.transformer.config.attention_head_dim,
- crops_coords=grid_crops_coords,
- grid_size=(grid_height, grid_width),
- temporal_size=num_frames,
- )
+ p = self.transformer.config.patch_size
+ p_t = self.transformer.config.patch_size_t
+
+ base_size_width = self.transformer.config.sample_width // p
+ base_size_height = self.transformer.config.sample_height // p
+
+ if p_t is None:
+ # CogVideoX 1.0
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ device=device,
+ )
+ else:
+ # CogVideoX 1.5
+ base_num_frames = (num_frames + p_t - 1) // p_t
+
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=None,
+ grid_size=(grid_height, grid_width),
+ temporal_size=base_num_frames,
+ grid_type="slice",
+ max_size=(base_size_height, base_size_width),
+ device=device,
+ )
- freqs_cos = freqs_cos.to(device=device)
- freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@property
@@ -471,6 +494,10 @@ def num_timesteps(self):
def attention_kwargs(self):
return self._attention_kwargs
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
@property
def interrupt(self):
return self._interrupt
@@ -481,9 +508,9 @@ def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
- height: int = 480,
- width: int = 720,
- num_frames: int = 49,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
@@ -583,14 +610,13 @@ def __call__(
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
- if num_frames > 49:
- raise ValueError(
- "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
- )
-
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+ num_frames = num_frames or self.transformer.config.sample_frames
+
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
@@ -605,6 +631,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
@@ -640,7 +667,16 @@ def __call__(
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
- # 5. Prepare latents.
+ # 5. Prepare latents
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+ patch_size_t = self.transformer.config.patch_size_t
+ additional_frames = 0
+ if patch_size_t is not None and latent_frames % patch_size_t != 0:
+ additional_frames = patch_size_t - latent_frames % patch_size_t
+ num_frames += additional_frames * self.vae_scale_factor_temporal
+
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
@@ -674,6 +710,7 @@ def __call__(
if self.interrupt:
continue
+ self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -729,7 +766,14 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
if not output_type == "latent":
+ # Discard any padding frames that were added for CogVideoX 1.5
+ latents = latents[:, additional_frames:]
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
index 3655075bd519..e37574ec9cb2 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
@@ -27,12 +27,19 @@
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -190,14 +197,12 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
- self.vae_scaling_factor_image = (
- self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -488,21 +493,39 @@ def _prepare_rotary_positional_embeddings(
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- grid_crops_coords = get_resize_crop_region_for_grid(
- (grid_height, grid_width), base_size_width, base_size_height
- )
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
- embed_dim=self.transformer.config.attention_head_dim,
- crops_coords=grid_crops_coords,
- grid_size=(grid_height, grid_width),
- temporal_size=num_frames,
- )
+ p = self.transformer.config.patch_size
+ p_t = self.transformer.config.patch_size_t
+
+ base_size_width = self.transformer.config.sample_width // p
+ base_size_height = self.transformer.config.sample_height // p
+
+ if p_t is None:
+ # CogVideoX 1.0
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ device=device,
+ )
+ else:
+ # CogVideoX 1.5
+ base_num_frames = (num_frames + p_t - 1) // p_t
+
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=None,
+ grid_size=(grid_height, grid_width),
+ temporal_size=base_num_frames,
+ grid_type="slice",
+ max_size=(base_size_height, base_size_width),
+ device=device,
+ )
- freqs_cos = freqs_cos.to(device=device)
- freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@property
@@ -517,6 +540,10 @@ def num_timesteps(self):
def attention_kwargs(self):
return self._attention_kwargs
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
@property
def interrupt(self):
return self._interrupt
@@ -528,8 +555,8 @@ def __call__(
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
control_video: Optional[List[Image.Image]] = None,
- height: int = 480,
- width: int = 720,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
@@ -634,6 +661,13 @@ def __call__(
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ if control_video is not None and isinstance(control_video[0], Image.Image):
+ control_video = [control_video]
+
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+ num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
+
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
@@ -650,6 +684,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
@@ -660,9 +695,6 @@ def __call__(
else:
batch_size = prompt_embeds.shape[0]
- if control_video is not None and isinstance(control_video[0], Image.Image):
- control_video = [control_video]
-
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
@@ -688,9 +720,18 @@ def __call__(
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
- # 5. Prepare latents.
+ # 5. Prepare latents
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+ patch_size_t = self.transformer.config.patch_size_t
+ if patch_size_t is not None and latent_frames % patch_size_t != 0:
+ raise ValueError(
+ f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
+ f"contains {latent_frames=}, which is not divisible."
+ )
+
latent_channels = self.transformer.config.in_channels // 2
- num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
@@ -730,6 +771,7 @@ def __call__(
if self.interrupt:
continue
+ self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -779,6 +821,11 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index 783dae569bec..59d7c4cad547 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -29,6 +29,7 @@
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -37,6 +38,13 @@
from .pipeline_output import CogVideoXPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -203,14 +211,12 @@ def __init__(
scheduler=scheduler,
)
self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
- self.vae_scaling_factor_image = (
- self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -367,6 +373,10 @@ def prepare_latents(
width // self.vae_scale_factor_spatial,
)
+ # For CogVideoX1.5, the latent should add 1 for padding (Not use)
+ if self.transformer.config.patch_size_t is not None:
+ shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]
+
image = image.unsqueeze(2) # [B, C, F, H, W]
if isinstance(generator, list):
@@ -377,7 +387,13 @@ def prepare_latents(
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
- image_latents = self.vae_scaling_factor_image * image_latents
+
+ if not self.vae.config.invert_scale_latents:
+ image_latents = self.vae_scaling_factor_image * image_latents
+ else:
+ # This is awkward but required because the CogVideoX team forgot to multiply the
+ # scaling factor during training :)
+ image_latents = 1 / self.vae_scaling_factor_image * image_latents
padding_shape = (
batch_size,
@@ -386,9 +402,15 @@ def prepare_latents(
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
+
latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
image_latents = torch.cat([image_latents, latent_padding], dim=1)
+ # Select the first frame along the second dimension
+ if self.transformer.config.patch_size_t is not None:
+ first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
+ image_latents = torch.cat([first_frame, image_latents], dim=1)
+
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
@@ -522,21 +544,39 @@ def _prepare_rotary_positional_embeddings(
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- grid_crops_coords = get_resize_crop_region_for_grid(
- (grid_height, grid_width), base_size_width, base_size_height
- )
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
- embed_dim=self.transformer.config.attention_head_dim,
- crops_coords=grid_crops_coords,
- grid_size=(grid_height, grid_width),
- temporal_size=num_frames,
- )
+ p = self.transformer.config.patch_size
+ p_t = self.transformer.config.patch_size_t
+
+ base_size_width = self.transformer.config.sample_width // p
+ base_size_height = self.transformer.config.sample_height // p
+
+ if p_t is None:
+ # CogVideoX 1.0
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ device=device,
+ )
+ else:
+ # CogVideoX 1.5
+ base_num_frames = (num_frames + p_t - 1) // p_t
+
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=None,
+ grid_size=(grid_height, grid_width),
+ temporal_size=base_num_frames,
+ grid_type="slice",
+ max_size=(base_size_height, base_size_width),
+ device=device,
+ )
- freqs_cos = freqs_cos.to(device=device)
- freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@property
@@ -551,6 +591,10 @@ def num_timesteps(self):
def attention_kwargs(self):
return self._attention_kwargs
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
@property
def interrupt(self):
return self._interrupt
@@ -562,8 +606,8 @@ def __call__(
image: PipelineImageInput,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
- height: int = 480,
- width: int = 720,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
num_frames: int = 49,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
@@ -666,14 +710,13 @@ def __call__(
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
- if num_frames > 49:
- raise ValueError(
- "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
- )
-
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+ num_frames = num_frames or self.transformer.config.sample_frames
+
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
@@ -689,6 +732,7 @@ def __call__(
negative_prompt_embeds=negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
+ self._current_timestep = None
self._attention_kwargs = attention_kwargs
self._interrupt = False
@@ -726,6 +770,15 @@ def __call__(
self._num_timesteps = len(timesteps)
# 5. Prepare latents
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+ patch_size_t = self.transformer.config.patch_size_t
+ additional_frames = 0
+ if patch_size_t is not None and latent_frames % patch_size_t != 0:
+ additional_frames = patch_size_t - latent_frames % patch_size_t
+ num_frames += additional_frames * self.vae_scale_factor_temporal
+
image = self.video_processor.preprocess(image, height=height, width=width).to(
device, dtype=prompt_embeds.dtype
)
@@ -754,6 +807,9 @@ def __call__(
else None
)
+ # 8. Create ofs embeds if required
+ ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
+
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -764,6 +820,7 @@ def __call__(
if self.interrupt:
continue
+ self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -778,6 +835,7 @@ def __call__(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
+ ofs=ofs_emb,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
@@ -822,7 +880,14 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
if not output_type == "latent":
+ # Discard any padding frames that were added for CogVideoX 1.5
+ latents = latents[:, additional_frames:]
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index e1e816eca16d..c4dc7e574f7e 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -27,12 +27,19 @@
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -206,14 +213,12 @@ def __init__(
)
self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
- self.vae_scaling_factor_image = (
- self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -373,12 +378,6 @@ def prepare_latents(
if latents is None:
if isinstance(generator, list):
- if len(generator) != batch_size:
- raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
- )
-
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
]
@@ -518,21 +517,39 @@ def _prepare_rotary_positional_embeddings(
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
- grid_crops_coords = get_resize_crop_region_for_grid(
- (grid_height, grid_width), base_size_width, base_size_height
- )
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
- embed_dim=self.transformer.config.attention_head_dim,
- crops_coords=grid_crops_coords,
- grid_size=(grid_height, grid_width),
- temporal_size=num_frames,
- )
+ p = self.transformer.config.patch_size
+ p_t = self.transformer.config.patch_size_t
+
+ base_size_width = self.transformer.config.sample_width // p
+ base_size_height = self.transformer.config.sample_height // p
+
+ if p_t is None:
+ # CogVideoX 1.0
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ device=device,
+ )
+ else:
+ # CogVideoX 1.5
+ base_num_frames = (num_frames + p_t - 1) // p_t
+
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=None,
+ grid_size=(grid_height, grid_width),
+ temporal_size=base_num_frames,
+ grid_type="slice",
+ max_size=(base_size_height, base_size_width),
+ device=device,
+ )
- freqs_cos = freqs_cos.to(device=device)
- freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@property
@@ -547,6 +564,10 @@ def num_timesteps(self):
def attention_kwargs(self):
return self._attention_kwargs
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
@property
def interrupt(self):
return self._interrupt
@@ -558,8 +579,8 @@ def __call__(
video: List[Image.Image] = None,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
- height: int = 480,
- width: int = 720,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
strength: float = 0.8,
@@ -662,6 +683,10 @@ def __call__(
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+ num_frames = len(video) if latents is None else latents.size(1)
+
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
@@ -679,6 +704,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
@@ -717,6 +743,16 @@ def __call__(
self._num_timesteps = len(timesteps)
# 5. Prepare latents
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+ patch_size_t = self.transformer.config.patch_size_t
+ if patch_size_t is not None and latent_frames % patch_size_t != 0:
+ raise ValueError(
+ f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
+ f"contains {latent_frames=}, which is not divisible."
+ )
+
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
video = video.to(device=device, dtype=prompt_embeds.dtype)
@@ -755,6 +791,7 @@ def __call__(
if self.interrupt:
continue
+ self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -810,6 +847,11 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 64fff61d2c32..0cd3943fbcd2 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -24,11 +24,18 @@
from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import CogView3PipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -38,7 +45,7 @@
>>> import torch
>>> from diffusers import CogView3PlusPipeline
- >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3B", torch_dtype=torch.bfloat16)
+ >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A photo of an astronaut riding a horse on mars"
@@ -153,9 +160,7 @@ def __init__(
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -656,6 +661,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
diff --git a/src/diffusers/pipelines/cogview4/__init__.py b/src/diffusers/pipelines/cogview4/__init__.py
new file mode 100644
index 000000000000..6a365e17fee7
--- /dev/null
+++ b/src/diffusers/pipelines/cogview4/__init__.py
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["CogView4PlusPipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_cogview4"] = ["CogView4Pipeline"]
+ _import_structure["pipeline_cogview4_control"] = ["CogView4ControlPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_cogview4 import CogView4Pipeline
+ from .pipeline_cogview4_control import CogView4ControlPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
new file mode 100644
index 000000000000..8550fa94f9e4
--- /dev/null
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -0,0 +1,684 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, GlmModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import VaeImageProcessor
+from ...loaders import CogView4LoraLoaderMixin
+from ...models import AutoencoderKL, CogView4Transformer2DModel
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from .pipeline_output import CogView4PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import CogView4Pipeline
+
+ >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ >>> image.save("output.png")
+ ```
+"""
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ base_shift: float = 0.25,
+ max_shift: float = 0.75,
+) -> float:
+ m = (image_seq_len / base_seq_len) ** 0.5
+ mu = m * max_shift + base_shift
+ return mu
+
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+
+ if timesteps is not None and sigmas is not None:
+ if not accepts_timesteps and not accepts_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif timesteps is not None and sigmas is None:
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif timesteps is None and sigmas is not None:
+ if not accepts_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using CogView4.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`GLMModel`]):
+ Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
+ tokenizer (`PreTrainedTokenizer`):
+ Tokenizer of class
+ [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
+ transformer ([`CogView4Transformer2DModel`]):
+ A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: GlmModel,
+ vae: AutoencoderKL,
+ transformer: CogView4Transformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def _get_glm_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ max_sequence_length: int = 1024,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="longest", # not use max length
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+ current_length = text_input_ids.shape[1]
+ pad_length = (16 - (current_length % 16)) % 16
+ if pad_length > 0:
+ pad_ids = torch.full(
+ (text_input_ids.shape[0], pad_length),
+ fill_value=self.tokenizer.pad_token_id,
+ dtype=text_input_ids.dtype,
+ device=text_input_ids.device,
+ )
+ text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ Number of images that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ max_sequence_length (`int`, defaults to `1024`):
+ Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)
+
+ seq_len = prompt_embeds.size(1)
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)
+
+ seq_len = negative_prompt_embeds.size(1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ if latents is not None:
+ return latents.to(device)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 5.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 1024,
+ ) -> Union[CogView4PipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. If not provided, it is set to 1024.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. If not provided it is set to 1024.
+ num_inference_steps (`int`, *optional*, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to `1`):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `224`):
+ Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`:
+ [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = (height, width)
+
+ # Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ self.do_classifier_free_guidance,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Prepare latents
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # Prepare additional timestep conditions
+ original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
+ target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
+ crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
+
+ original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
+ target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
+ crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
+
+ # Prepare timesteps
+ image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
+ self.transformer.config.patch_size**2
+ )
+ timesteps = (
+ np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
+ if timesteps is None
+ else np.array(timesteps)
+ )
+ timesteps = timesteps.astype(np.int64).astype(np.float32)
+ sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("base_shift", 0.25),
+ self.scheduler.config.get("max_shift", 0.75),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
+ )
+ self._num_timesteps = len(timesteps)
+
+ # Denoising loop
+ transformer_dtype = self.transformer.dtype
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred_cond = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=negative_prompt_embeds,
+ timestep=timestep,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ else:
+ noise_pred = noise_pred_cond
+
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return CogView4PipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
new file mode 100644
index 000000000000..7613bc3d0f40
--- /dev/null
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
@@ -0,0 +1,732 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, GlmModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...models import AutoencoderKL, CogView4Transformer2DModel
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from .pipeline_output import CogView4PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import CogView4ControlPipeline
+
+ >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
+ >>> control_image = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+ ... )
+ >>> prompt = "A bird in space"
+ >>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0]
+ >>> image.save("cogview4-control.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ base_shift: float = 0.25,
+ max_shift: float = 0.75,
+) -> float:
+ m = (image_seq_len / base_seq_len) ** 0.5
+ mu = m * max_shift + base_shift
+ return mu
+
+
+# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+
+ if timesteps is not None and sigmas is not None:
+ if not accepts_timesteps and not accepts_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif timesteps is not None and sigmas is None:
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif timesteps is None and sigmas is not None:
+ if not accepts_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class CogView4ControlPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using CogView4.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`GLMModel`]):
+ Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
+ tokenizer (`PreTrainedTokenizer`):
+ Tokenizer of class
+ [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
+ transformer ([`CogView4Transformer2DModel`]):
+ A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: GlmModel,
+ vae: AutoencoderKL,
+ transformer: CogView4Transformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline._get_glm_embeds
+ def _get_glm_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ max_sequence_length: int = 1024,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="longest", # not use max length
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+ current_length = text_input_ids.shape[1]
+ pad_length = (16 - (current_length % 16)) % 16
+ if pad_length > 0:
+ pad_ids = torch.full(
+ (text_input_ids.shape[0], pad_length),
+ fill_value=self.tokenizer.pad_token_id,
+ dtype=text_input_ids.dtype,
+ device=text_input_ids.device,
+ )
+ text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ Number of images that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ max_sequence_length (`int`, defaults to `1024`):
+ Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype)
+
+ seq_len = prompt_embeds.size(1)
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype)
+
+ seq_len = negative_prompt_embeds.size(1)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ if latents is not None:
+ return latents.to(device)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 5.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 1024,
+ ) -> Union[CogView4PipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. If not provided, it is set to 1024.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. If not provided it is set to 1024.
+ num_inference_steps (`int`, *optional*, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to `1`):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead of a plain
+ tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `224`):
+ Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
+ Examples:
+
+ Returns:
+ [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`:
+ [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = (height, width)
+
+ # Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ self.do_classifier_free_guidance,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Prepare latents
+ latent_channels = self.transformer.config.in_channels // 2
+
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+ height, width = control_image.shape[-2:]
+
+ vae_shift_factor = 0
+
+ control_image = self.vae.encode(control_image).latent_dist.sample()
+ control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
+
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # Prepare additional timestep conditions
+ original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
+ target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
+ crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
+
+ original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
+ target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
+ crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
+
+ # Prepare timesteps
+ image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
+ self.transformer.config.patch_size**2
+ )
+
+ timesteps = (
+ np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
+ if timesteps is None
+ else np.array(timesteps)
+ )
+ timesteps = timesteps.astype(np.int64).astype(np.float32)
+ sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("base_shift", 0.25),
+ self.scheduler.config.get("max_shift", 0.75),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
+ )
+ self._num_timesteps = len(timesteps)
+ # Denoising loop
+ transformer_dtype = self.transformer.dtype
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred_cond = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=negative_prompt_embeds,
+ timestep=timestep,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ else:
+ noise_pred = noise_pred_cond
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return CogView4PipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/cogview4/pipeline_output.py b/src/diffusers/pipelines/cogview4/pipeline_output.py
new file mode 100644
index 000000000000..4efec1310845
--- /dev/null
+++ b/src/diffusers/pipelines/cogview4/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class CogView4PipelineOutput(BaseOutput):
+ """
+ Output class for CogView3 pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/consisid/__init__.py b/src/diffusers/pipelines/consisid/__init__.py
new file mode 100644
index 000000000000..5052e146f1df
--- /dev/null
+++ b/src/diffusers/pipelines/consisid/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_consisid"] = ["ConsisIDPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_consisid import ConsisIDPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/consisid/consisid_utils.py b/src/diffusers/pipelines/consisid/consisid_utils.py
new file mode 100644
index 000000000000..874b3d76149b
--- /dev/null
+++ b/src/diffusers/pipelines/consisid/consisid_utils.py
@@ -0,0 +1,357 @@
+import importlib.util
+import os
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image, ImageOps
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import normalize, resize
+
+from ...utils import get_logger, load_image
+
+
+logger = get_logger(__name__)
+
+_insightface_available = importlib.util.find_spec("insightface") is not None
+_consisid_eva_clip_available = importlib.util.find_spec("consisid_eva_clip") is not None
+_facexlib_available = importlib.util.find_spec("facexlib") is not None
+
+if _insightface_available:
+ import insightface
+ from insightface.app import FaceAnalysis
+else:
+ raise ImportError("insightface is not available. Please install it using 'pip install insightface'.")
+
+if _consisid_eva_clip_available:
+ from consisid_eva_clip import create_model_and_transforms
+ from consisid_eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+else:
+ raise ImportError("consisid_eva_clip is not available. Please install it using 'pip install consisid_eva_clip'.")
+
+if _facexlib_available:
+ from facexlib.parsing import init_parsing_model
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+else:
+ raise ImportError("facexlib is not available. Please install it using 'pip install facexlib'.")
+
+
+def resize_numpy_image_long(image, resize_long_edge=768):
+ """
+ Resize the input image to a specified long edge while maintaining aspect ratio.
+
+ Args:
+ image (numpy.ndarray): Input image (H x W x C or H x W).
+ resize_long_edge (int): The target size for the long edge of the image. Default is 768.
+
+ Returns:
+ numpy.ndarray: Resized image with the long edge matching `resize_long_edge`, while maintaining the aspect
+ ratio.
+ """
+
+ h, w = image.shape[:2]
+ if max(h, w) <= resize_long_edge:
+ return image
+ k = resize_long_edge / max(h, w)
+ h = int(h * k)
+ w = int(w * k)
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
+ return image
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == "float64":
+ img = img.astype("float32")
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def to_gray(img):
+ """
+ Converts an RGB image to grayscale by applying the standard luminosity formula.
+
+ Args:
+ img (torch.Tensor): The input image tensor with shape (batch_size, channels, height, width).
+ The image is expected to be in RGB format (3 channels).
+
+ Returns:
+ torch.Tensor: The grayscale image tensor with shape (batch_size, 3, height, width).
+ The grayscale values are replicated across all three channels.
+ """
+ x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
+ x = x.repeat(1, 3, 1, 1)
+ return x
+
+
+def process_face_embeddings(
+ face_helper_1,
+ clip_vision_model,
+ face_helper_2,
+ eva_transform_mean,
+ eva_transform_std,
+ app,
+ device,
+ weight_dtype,
+ image,
+ original_id_image=None,
+ is_align_face=True,
+):
+ """
+ Process face embeddings from an image, extracting relevant features such as face embeddings, landmarks, and parsed
+ face features using a series of face detection and alignment tools.
+
+ Args:
+ face_helper_1: Face helper object (first helper) for alignment and landmark detection.
+ clip_vision_model: Pre-trained CLIP vision model used for feature extraction.
+ face_helper_2: Face helper object (second helper) for embedding extraction.
+ eva_transform_mean: Mean values for image normalization before passing to EVA model.
+ eva_transform_std: Standard deviation values for image normalization before passing to EVA model.
+ app: Application instance used for face detection.
+ device: Device (CPU or GPU) where the computations will be performed.
+ weight_dtype: Data type of the weights for precision (e.g., `torch.float32`).
+ image: Input image in RGB format with pixel values in the range [0, 255].
+ original_id_image: (Optional) Original image for feature extraction if `is_align_face` is False.
+ is_align_face: Boolean flag indicating whether face alignment should be performed.
+
+ Returns:
+ Tuple:
+ - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding
+ - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors.
+ - return_face_features_image_2: Processed face features image after normalization and parsing.
+ - face_kps: Keypoints of the face detected in the image.
+ """
+
+ face_helper_1.clean_all()
+ image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+ # get antelopev2 embedding
+ face_info = app.get(image_bgr)
+ if len(face_info) > 0:
+ face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[
+ -1
+ ] # only use the maximum face
+ id_ante_embedding = face_info["embedding"] # (512,)
+ face_kps = face_info["kps"]
+ else:
+ id_ante_embedding = None
+ face_kps = None
+
+ # using facexlib to detect and align face
+ face_helper_1.read_image(image_bgr)
+ face_helper_1.get_face_landmarks_5(only_center_face=True)
+ if face_kps is None:
+ face_kps = face_helper_1.all_landmarks_5[0]
+ face_helper_1.align_warp_face()
+ if len(face_helper_1.cropped_faces) == 0:
+ raise RuntimeError("facexlib align face fail")
+ align_face = face_helper_1.cropped_faces[0] # (512, 512, 3) # RGB
+
+ # incase insightface didn't detect face
+ if id_ante_embedding is None:
+ logger.warning("Failed to detect face using insightface. Extracting embedding with align face")
+ id_ante_embedding = face_helper_2.get_feat(align_face)
+
+ id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype) # torch.Size([512])
+ if id_ante_embedding.ndim == 1:
+ id_ante_embedding = id_ante_embedding.unsqueeze(0) # torch.Size([1, 512])
+
+ # parsing
+ if is_align_face:
+ input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512])
+ input = input.to(device)
+ parsing_out = face_helper_1.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
+ parsing_out = parsing_out.argmax(dim=1, keepdim=True) # torch.Size([1, 1, 512, 512])
+ bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
+ bg = sum(parsing_out == i for i in bg_label).bool()
+ white_image = torch.ones_like(input) # torch.Size([1, 3, 512, 512])
+ # only keep the face features
+ return_face_features_image = torch.where(bg, white_image, to_gray(input)) # torch.Size([1, 3, 512, 512])
+ return_face_features_image_2 = torch.where(bg, white_image, input) # torch.Size([1, 3, 512, 512])
+ else:
+ original_image_bgr = cv2.cvtColor(original_id_image, cv2.COLOR_RGB2BGR)
+ input = img2tensor(original_image_bgr, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512])
+ input = input.to(device)
+ return_face_features_image = return_face_features_image_2 = input
+
+ # transform img before sending to eva-clip-vit
+ face_features_image = resize(
+ return_face_features_image, clip_vision_model.image_size, InterpolationMode.BICUBIC
+ ) # torch.Size([1, 3, 336, 336])
+ face_features_image = normalize(face_features_image, eva_transform_mean, eva_transform_std)
+ id_cond_vit, id_vit_hidden = clip_vision_model(
+ face_features_image.to(weight_dtype), return_all_features=False, return_hidden=True, shuffle=False
+ ) # torch.Size([1, 768]), list(torch.Size([1, 577, 1024]))
+ id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
+ id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)
+
+ id_cond = torch.cat(
+ [id_ante_embedding, id_cond_vit], dim=-1
+ ) # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280])
+
+ return (
+ id_cond,
+ id_vit_hidden,
+ return_face_features_image_2,
+ face_kps,
+ ) # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024]))
+
+
+def process_face_embeddings_infer(
+ face_helper_1,
+ clip_vision_model,
+ face_helper_2,
+ eva_transform_mean,
+ eva_transform_std,
+ app,
+ device,
+ weight_dtype,
+ img_file_path,
+ is_align_face=True,
+):
+ """
+ Process face embeddings from an input image for inference, including alignment, feature extraction, and embedding
+ concatenation.
+
+ Args:
+ face_helper_1: Face helper object (first helper) for alignment and landmark detection.
+ clip_vision_model: Pre-trained CLIP vision model used for feature extraction.
+ face_helper_2: Face helper object (second helper) for embedding extraction.
+ eva_transform_mean: Mean values for image normalization before passing to EVA model.
+ eva_transform_std: Standard deviation values for image normalization before passing to EVA model.
+ app: Application instance used for face detection.
+ device: Device (CPU or GPU) where the computations will be performed.
+ weight_dtype: Data type of the weights for precision (e.g., `torch.float32`).
+ img_file_path: Path to the input image file (string) or a numpy array representing an image.
+ is_align_face: Boolean flag indicating whether face alignment should be performed (default: True).
+
+ Returns:
+ Tuple:
+ - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding.
+ - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors.
+ - image: Processed face image after feature extraction and alignment.
+ - face_kps: Keypoints of the face detected in the image.
+ """
+
+ # Load and preprocess the input image
+ if isinstance(img_file_path, str):
+ image = np.array(load_image(image=img_file_path).convert("RGB"))
+ else:
+ image = np.array(ImageOps.exif_transpose(Image.fromarray(img_file_path)).convert("RGB"))
+
+ # Resize image to ensure the longer side is 1024 pixels
+ image = resize_numpy_image_long(image, 1024)
+ original_id_image = image
+
+ # Process the image to extract face embeddings and related features
+ id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(
+ face_helper_1,
+ clip_vision_model,
+ face_helper_2,
+ eva_transform_mean,
+ eva_transform_std,
+ app,
+ device,
+ weight_dtype,
+ image,
+ original_id_image,
+ is_align_face,
+ )
+
+ # Convert the aligned cropped face image (torch tensor) to a numpy array
+ tensor = align_crop_face_image.cpu().detach()
+ tensor = tensor.squeeze()
+ tensor = tensor.permute(1, 2, 0)
+ tensor = tensor.numpy() * 255
+ tensor = tensor.astype(np.uint8)
+ image = ImageOps.exif_transpose(Image.fromarray(tensor))
+
+ return id_cond, id_vit_hidden, image, face_kps
+
+
+def prepare_face_models(model_path, device, dtype):
+ """
+ Prepare all face models for the facial recognition task.
+
+ Parameters:
+ - model_path: Path to the directory containing model files.
+ - device: The device (e.g., 'cuda', 'cpu') where models will be loaded.
+ - dtype: Data type (e.g., torch.float32) for model inference.
+
+ Returns:
+ - face_helper_1: First face restoration helper.
+ - face_helper_2: Second face restoration helper.
+ - face_clip_model: CLIP model for face extraction.
+ - eva_transform_mean: Mean value for image normalization.
+ - eva_transform_std: Standard deviation value for image normalization.
+ - face_main_model: Main face analysis model.
+ """
+ # get helper model
+ face_helper_1 = FaceRestoreHelper(
+ upscale_factor=1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model="retinaface_resnet50",
+ save_ext="png",
+ device=device,
+ model_rootpath=os.path.join(model_path, "face_encoder"),
+ )
+ face_helper_1.face_parse = None
+ face_helper_1.face_parse = init_parsing_model(
+ model_name="bisenet", device=device, model_rootpath=os.path.join(model_path, "face_encoder")
+ )
+ face_helper_2 = insightface.model_zoo.get_model(
+ f"{model_path}/face_encoder/models/antelopev2/glintr100.onnx", providers=["CUDAExecutionProvider"]
+ )
+ face_helper_2.prepare(ctx_id=0)
+
+ # get local facial extractor part 1
+ model, _, _ = create_model_and_transforms(
+ "EVA02-CLIP-L-14-336",
+ os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"),
+ force_custom_clip=True,
+ )
+ face_clip_model = model.visual
+ eva_transform_mean = getattr(face_clip_model, "image_mean", OPENAI_DATASET_MEAN)
+ eva_transform_std = getattr(face_clip_model, "image_std", OPENAI_DATASET_STD)
+ if not isinstance(eva_transform_mean, (list, tuple)):
+ eva_transform_mean = (eva_transform_mean,) * 3
+ if not isinstance(eva_transform_std, (list, tuple)):
+ eva_transform_std = (eva_transform_std,) * 3
+ eva_transform_mean = eva_transform_mean
+ eva_transform_std = eva_transform_std
+
+ # get local facial extractor part 2
+ face_main_model = FaceAnalysis(
+ name="antelopev2", root=os.path.join(model_path, "face_encoder"), providers=["CUDAExecutionProvider"]
+ )
+ face_main_model.prepare(ctx_id=0, det_size=(640, 640))
+
+ # move face models to device
+ face_helper_1.face_det.eval()
+ face_helper_1.face_parse.eval()
+ face_clip_model.eval()
+ face_helper_1.face_det.to(device)
+ face_helper_1.face_parse.to(device)
+ face_clip_model.to(device, dtype=dtype)
+
+ return face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std
diff --git a/src/diffusers/pipelines/consisid/pipeline_consisid.py b/src/diffusers/pipelines/consisid/pipeline_consisid.py
new file mode 100644
index 000000000000..1a99c2a0e9ee
--- /dev/null
+++ b/src/diffusers/pipelines/consisid/pipeline_consisid.py
@@ -0,0 +1,971 @@
+# Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+import PIL
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import CogVideoXLoraLoaderMixin
+from ...models import AutoencoderKLCogVideoX, ConsisIDTransformer3DModel
+from ...models.embeddings import get_3d_rotary_pos_embed
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import CogVideoXDPMScheduler
+from ...utils import logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from .pipeline_output import ConsisIDPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import ConsisIDPipeline
+ >>> from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
+ >>> from diffusers.utils import export_to_video
+ >>> from huggingface_hub import snapshot_download
+
+ >>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
+ >>> (
+ ... face_helper_1,
+ ... face_helper_2,
+ ... face_clip_model,
+ ... face_main_model,
+ ... eva_transform_mean,
+ ... eva_transform_std,
+ ... ) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
+ >>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> # ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body).
+ >>> prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel."
+ >>> image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true"
+
+ >>> id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
+ ... face_helper_1,
+ ... face_clip_model,
+ ... face_helper_2,
+ ... eva_transform_mean,
+ ... eva_transform_std,
+ ... face_main_model,
+ ... "cuda",
+ ... torch.bfloat16,
+ ... image,
+ ... is_align_face=True,
+ ... )
+
+ >>> video = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... num_inference_steps=50,
+ ... guidance_scale=6.0,
+ ... use_dynamic_cfg=False,
+ ... id_vit_hidden=id_vit_hidden,
+ ... id_cond=id_cond,
+ ... kps_cond=face_kps,
+ ... generator=torch.Generator("cuda").manual_seed(42),
+ ... )
+ >>> export_to_video(video.frames[0], "output.mp4", fps=8)
+ ```
+"""
+
+
+def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
+ """
+ This function draws keypoints and the limbs connecting them on an image.
+
+ Parameters:
+ - image_pil (PIL.Image): Input image as a PIL object.
+ - kps (list of tuples): A list of keypoints where each keypoint is a tuple of (x, y) coordinates.
+ - color_list (list of tuples, optional): List of colors (in RGB format) for each keypoint. Default is a set of five
+ colors.
+
+ Returns:
+ - PIL.Image: Image with the keypoints and limbs drawn.
+ """
+
+ stickwidth = 4
+ limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
+ kps = np.array(kps)
+
+ w, h = image_pil.size
+ out_img = np.zeros([h, w, 3])
+
+ for i in range(len(limbSeq)):
+ index = limbSeq[i]
+ color = color_list[index[0]]
+
+ x = kps[index][:, 0]
+ y = kps[index][:, 1]
+ length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+ polygon = cv2.ellipse2Poly(
+ (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
+ )
+ out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
+ out_img = (out_img * 0.6).astype(np.uint8)
+
+ for idx_kp, kp in enumerate(kps):
+ color = color_list[idx_kp]
+ x, y = kp
+ out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
+
+ out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
+ return out_img_pil
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ """
+ This function calculates the resize and crop region for an image to fit a target width and height while preserving
+ the aspect ratio.
+
+ Parameters:
+ - src (tuple): A tuple containing the source image's height (h) and width (w).
+ - tgt_width (int): The target width to resize the image.
+ - tgt_height (int): The target height to resize the image.
+
+ Returns:
+ - tuple: Two tuples representing the crop region:
+ 1. The top-left coordinates of the crop region.
+ 2. The bottom-right coordinates of the crop region.
+ """
+
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class ConsisIDPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation using ConsisID.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. ConsisID uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`ConsisIDTransformer3DModel`]):
+ A text conditioned `ConsisIDTransformer3DModel` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: ConsisIDTransformer3DModel,
+ scheduler: CogVideoXDPMScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor_spatial = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+ )
+ self.vae_scaling_factor_image = (
+ self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self,
+ image: torch.Tensor,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ num_frames: int = 13,
+ height: int = 60,
+ width: int = 90,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ kps_cond: Optional[torch.Tensor] = None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_frames,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ image = image.unsqueeze(2) # [B, C, F, H, W]
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+ ]
+ if kps_cond is not None:
+ kps_cond = kps_cond.unsqueeze(2)
+ kps_cond_latents = [
+ retrieve_latents(self.vae.encode(kps_cond[i].unsqueeze(0)), generator[i])
+ for i in range(batch_size)
+ ]
+ else:
+ image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
+ if kps_cond is not None:
+ kps_cond = kps_cond.unsqueeze(2)
+ kps_cond_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in kps_cond]
+
+ image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
+ image_latents = self.vae_scaling_factor_image * image_latents
+
+ if kps_cond is not None:
+ kps_cond_latents = torch.cat(kps_cond_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
+ kps_cond_latents = self.vae_scaling_factor_image * kps_cond_latents
+
+ padding_shape = (
+ batch_size,
+ num_frames - 2,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+ else:
+ padding_shape = (
+ batch_size,
+ num_frames - 1,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
+ if kps_cond is not None:
+ image_latents = torch.cat([image_latents, kps_cond_latents, latent_padding], dim=1)
+ else:
+ image_latents = torch.cat([image_latents, latent_padding], dim=1)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents, image_latents
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
+ latents = 1 / self.vae_scaling_factor_image * latents
+
+ frames = self.vae.decode(latents).sample
+ return frames
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ image,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ latents=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def _prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_width = self.transformer.config.sample_width // self.transformer.config.patch_size
+ base_size_height = self.transformer.config.sample_height // self.transformer.config.patch_size
+
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ device=device,
+ )
+
+ return freqs_cos, freqs_sin
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 480,
+ width: int = 720,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ id_vit_hidden: Optional[torch.Tensor] = None,
+ id_cond: Optional[torch.Tensor] = None,
+ kps_cond: Optional[torch.Tensor] = None,
+ ) -> Union[ConsisIDPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
+ num_frames (`int`, defaults to `49`):
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+ contain 1 extra frame because ConsisID is conditioned with (num_seconds * fps + 1) frames where
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+ needs to be satisfied is that of divisibility mentioned above.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 6):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ use_dynamic_cfg (`bool`, *optional*, defaults to `False`):
+ If True, dynamically adjusts the guidance scale during inference. This allows the model to use a
+ progressive guidance scale, improving the balance between text-guided generation and image quality over
+ the course of the inference steps. Typically, early inference steps use a higher guidance scale for
+ more faithful image generation, while later steps reduce it for more diverse and natural results.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `226`):
+ Maximum sequence length in encoded prompt. Must be consistent with
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+ id_vit_hidden (`Optional[torch.Tensor]`, *optional*):
+ The tensor representing the hidden features extracted from the face model, which are used to condition
+ the local facial extractor. This is crucial for the model to obtain high-frequency information of the
+ face. If not provided, the local facial extractor will not run normally.
+ id_cond (`Optional[torch.Tensor]`, *optional*):
+ The tensor representing the hidden features extracted from the clip model, which are used to condition
+ the local facial extractor. This is crucial for the model to edit facial features If not provided, the
+ local facial extractor will not run normally.
+ kps_cond (`Optional[torch.Tensor]`, *optional*):
+ A tensor that determines whether the global facial extractor use keypoint information for conditioning.
+ If provided, this tensor controls whether facial keypoints such as eyes, nose, and mouth landmarks are
+ used during the generation process. This helps ensure the model retains more facial low-frequency
+ information.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] or `tuple`:
+ [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+ num_frames = num_frames or self.transformer.config.sample_frames
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ image=image,
+ prompt=prompt,
+ height=height,
+ width=width,
+ negative_prompt=negative_prompt,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ latents=latents,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents
+ is_kps = getattr(self.transformer.config, "is_kps", False)
+ kps_cond = kps_cond if is_kps else None
+ if kps_cond is not None:
+ kps_cond = draw_kps(image, kps_cond)
+ kps_cond = self.video_processor.preprocess(kps_cond, height=height, width=width).to(
+ device, dtype=prompt_embeds.dtype
+ )
+
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
+ device, dtype=prompt_embeds.dtype
+ )
+
+ latent_channels = self.transformer.config.in_channels // 2
+ latents, image_latents = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ kps_cond,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ timesteps_cpu = timesteps.cpu()
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ id_vit_hidden=id_vit_hidden,
+ id_cond=id_cond,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (
+ 1
+ - math.cos(
+ math.pi
+ * ((num_inference_steps - timesteps_cpu[i].item()) / num_inference_steps) ** 5.0
+ )
+ )
+ / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return ConsisIDPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/consisid/pipeline_output.py b/src/diffusers/pipelines/consisid/pipeline_output.py
new file mode 100644
index 000000000000..dd4a63aa50b9
--- /dev/null
+++ b/src/diffusers/pipelines/consisid/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class ConsisIDPipelineOutput(BaseOutput):
+ r"""
+ Output class for ConsisID pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
index d2f67a698917..f0c71655e628 100644
--- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
+++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
@@ -19,6 +19,7 @@
from ...models import UNet2DModel
from ...schedulers import CMStochasticIterativeScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -26,6 +27,13 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -263,6 +271,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, sample)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 6. Post-process image sample
image = self.postprocess_image(sample, output_type=output_type)
diff --git a/src/diffusers/pipelines/controlnet/__init__.py b/src/diffusers/pipelines/controlnet/__init__.py
index b1671050c93f..a49dccf235a3 100644
--- a/src/diffusers/pipelines/controlnet/__init__.py
+++ b/src/diffusers/pipelines/controlnet/__init__.py
@@ -1,80 +1,86 @@
-from typing import TYPE_CHECKING
-
-from ...utils import (
- DIFFUSERS_SLOW_IMPORT,
- OptionalDependencyNotAvailable,
- _LazyModule,
- get_objects_from_module,
- is_flax_available,
- is_torch_available,
- is_transformers_available,
-)
-
-
-_dummy_objects = {}
-_import_structure = {}
-
-try:
- if not (is_transformers_available() and is_torch_available()):
- raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
- from ...utils import dummy_torch_and_transformers_objects # noqa F403
-
- _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
-else:
- _import_structure["multicontrolnet"] = ["MultiControlNetModel"]
- _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
- _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
- _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
- _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
- _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
- _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
- _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
-try:
- if not (is_transformers_available() and is_flax_available()):
- raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
- from ...utils import dummy_flax_and_transformers_objects # noqa F403
-
- _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
-else:
- _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
-
-
-if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
- try:
- if not (is_transformers_available() and is_torch_available()):
- raise OptionalDependencyNotAvailable()
-
- except OptionalDependencyNotAvailable:
- from ...utils.dummy_torch_and_transformers_objects import *
- else:
- from .multicontrolnet import MultiControlNetModel
- from .pipeline_controlnet import StableDiffusionControlNetPipeline
- from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
- from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
- from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
- from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
- from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
- from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
-
- try:
- if not (is_transformers_available() and is_flax_available()):
- raise OptionalDependencyNotAvailable()
- except OptionalDependencyNotAvailable:
- from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
- else:
- from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
-
-
-else:
- import sys
-
- sys.modules[__name__] = _LazyModule(
- __name__,
- globals()["__file__"],
- _import_structure,
- module_spec=__spec__,
- )
- for name, value in _dummy_objects.items():
- setattr(sys.modules[__name__], name, value)
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_flax_available,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["multicontrolnet"] = ["MultiControlNetModel"]
+ _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
+ _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
+ _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
+ _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
+ _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
+ _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
+ _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
+ _import_structure["pipeline_controlnet_union_inpaint_sd_xl"] = ["StableDiffusionXLControlNetUnionInpaintPipeline"]
+ _import_structure["pipeline_controlnet_union_sd_xl"] = ["StableDiffusionXLControlNetUnionPipeline"]
+ _import_structure["pipeline_controlnet_union_sd_xl_img2img"] = ["StableDiffusionXLControlNetUnionImg2ImgPipeline"]
+try:
+ if not (is_transformers_available() and is_flax_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_flax_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
+else:
+ _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .multicontrolnet import MultiControlNetModel
+ from .pipeline_controlnet import StableDiffusionControlNetPipeline
+ from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
+ from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
+ from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
+ from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
+ from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
+ from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
+ from .pipeline_controlnet_union_inpaint_sd_xl import StableDiffusionXLControlNetUnionInpaintPipeline
+ from .pipeline_controlnet_union_sd_xl import StableDiffusionXLControlNetUnionPipeline
+ from .pipeline_controlnet_union_sd_xl_img2img import StableDiffusionXLControlNetUnionImg2ImgPipeline
+
+ try:
+ if not (is_transformers_available() and is_flax_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py
index e3c5ec6eed03..6526dd8c9a57 100644
--- a/src/diffusers/pipelines/controlnet/multicontrolnet.py
+++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py
@@ -1,183 +1,12 @@
-import os
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-from ...models.controlnet import ControlNetModel, ControlNetOutput
-from ...models.modeling_utils import ModelMixin
-from ...utils import logging
+from ...models.controlnets.multicontrolnet import MultiControlNetModel
+from ...utils import deprecate, logging
logger = logging.get_logger(__name__)
-class MultiControlNetModel(ModelMixin):
- r"""
- Multiple `ControlNetModel` wrapper class for Multi-ControlNet
-
- This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
- compatible with `ControlNetModel`.
-
- Args:
- controlnets (`List[ControlNetModel]`):
- Provides additional conditioning to the unet during the denoising process. You must set multiple
- `ControlNetModel` as a list.
- """
-
- def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
- super().__init__()
- self.nets = nn.ModuleList(controlnets)
-
- def forward(
- self,
- sample: torch.Tensor,
- timestep: Union[torch.Tensor, float, int],
- encoder_hidden_states: torch.Tensor,
- controlnet_cond: List[torch.tensor],
- conditioning_scale: List[float],
- class_labels: Optional[torch.Tensor] = None,
- timestep_cond: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- guess_mode: bool = False,
- return_dict: bool = True,
- ) -> Union[ControlNetOutput, Tuple]:
- for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
- down_samples, mid_sample = controlnet(
- sample=sample,
- timestep=timestep,
- encoder_hidden_states=encoder_hidden_states,
- controlnet_cond=image,
- conditioning_scale=scale,
- class_labels=class_labels,
- timestep_cond=timestep_cond,
- attention_mask=attention_mask,
- added_cond_kwargs=added_cond_kwargs,
- cross_attention_kwargs=cross_attention_kwargs,
- guess_mode=guess_mode,
- return_dict=return_dict,
- )
-
- # merge samples
- if i == 0:
- down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
- else:
- down_block_res_samples = [
- samples_prev + samples_curr
- for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
- ]
- mid_block_res_sample += mid_sample
-
- return down_block_res_samples, mid_block_res_sample
-
- def save_pretrained(
- self,
- save_directory: Union[str, os.PathLike],
- is_main_process: bool = True,
- save_function: Callable = None,
- safe_serialization: bool = True,
- variant: Optional[str] = None,
- ):
- """
- Save a model and its configuration file to a directory, so that it can be re-loaded using the
- `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to which to save. Will be created if it doesn't exist.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful when in distributed training like
- TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
- the main process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful on distributed training like TPUs when one
- need to replace `torch.save` by another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
- variant (`str`, *optional*):
- If specified, weights are saved in the format pytorch_model..bin.
- """
- for idx, controlnet in enumerate(self.nets):
- suffix = "" if idx == 0 else f"_{idx}"
- controlnet.save_pretrained(
- save_directory + suffix,
- is_main_process=is_main_process,
- save_function=save_function,
- safe_serialization=safe_serialization,
- variant=variant,
- )
-
- @classmethod
- def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
- r"""
- Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
-
- The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
- the model, you should first set it back in training mode with `model.train()`.
-
- The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
- pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
- task.
-
- The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
- weights are discarded.
-
- Parameters:
- pretrained_model_path (`os.PathLike`):
- A path to a *directory* containing model weights saved using
- [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
- `./my_model_directory/controlnet`.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
- will be automatically derived from the model's weights.
- output_loading_info(`bool`, *optional*, defaults to `False`):
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
- A map that specifies where each submodule should go. It doesn't need to be refined to each
- parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
- same device.
-
- To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
- more information about each option see [designing a device
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
- max_memory (`Dict`, *optional*):
- A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
- GPU and the available CPU RAM if unset.
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
- Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
- also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
- model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
- setting this argument to `True` will raise an error.
- variant (`str`, *optional*):
- If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is
- ignored when using `from_flax`.
- use_safetensors (`bool`, *optional*, defaults to `None`):
- If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
- `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
- `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
- """
- idx = 0
- controlnets = []
-
- # load controlnet and append to list until no controlnet directory exists anymore
- # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
- # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
- model_path_to_load = pretrained_model_path
- while os.path.isdir(model_path_to_load):
- controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
- controlnets.append(controlnet)
-
- idx += 1
- model_path_to_load = pretrained_model_path + f"_{idx}"
-
- logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
-
- if len(controlnets) == 0:
- raise ValueError(
- f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
- )
-
- return cls(controlnets)
+class MultiControlNetModel(MultiControlNetModel):
+ def __init__(self, *args, **kwargs):
+ deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead."
+ deprecate("diffusers.pipelines.controlnet.multicontrolnet.MultiControlNetModel", "0.34", deprecation_message)
+ super().__init__(*args, **kwargs)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 165906b2a643..a5e38278cdf2 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -25,12 +25,13 @@
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -40,9 +41,15 @@
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from .multicontrolnet import MultiControlNetModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -73,7 +80,7 @@
>>> # load control net and stable diffusion v1-5
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe = StableDiffusionControlNetPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> # speed up diffusion process with faster scheduler and memory optimization
@@ -191,8 +198,8 @@ class StableDiffusionControlNetPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -200,7 +207,7 @@ class StableDiffusionControlNetPipeline(
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "image"]
def __init__(
self,
@@ -247,7 +254,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -1316,6 +1323,7 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ image = callback_outputs.pop("image", image)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1324,6 +1332,8 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
index 86e0ddef663e..88c387d48dd2 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
@@ -21,6 +21,7 @@
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -31,8 +32,16 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -401,6 +410,10 @@ def __call__(
t,
latents,
)["prev_sample"]
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index 4cdec5b3cf5f..be2874f48e69 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -24,12 +24,13 @@
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -39,9 +40,15 @@
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from .multicontrolnet import MultiControlNetModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -72,7 +79,7 @@
>>> # load control net and stable diffusion v1-5
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> # speed up diffusion process with faster scheduler and memory optimization
@@ -169,8 +176,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -178,7 +185,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"]
def __init__(
self,
@@ -225,7 +232,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -1287,6 +1294,7 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ control_image = callback_outputs.pop("control_image", control_image)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1295,6 +1303,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index da5a02d14108..16d3529ed38a 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -26,12 +26,13 @@
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -41,9 +42,15 @@
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from .multicontrolnet import MultiControlNetModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -84,7 +91,7 @@
... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
... )
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
@@ -142,11 +149,11 @@ class StableDiffusionControlNetInpaintPipeline(
This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
- ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as
- default text-to-image Stable Diffusion checkpoints
- ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image
- Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as
- [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
+ ([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting))
+ as well as default text-to-image Stable Diffusion checkpoints
+ ([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)).
+ Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on
+ those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
@@ -168,8 +175,8 @@ class StableDiffusionControlNetInpaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -177,7 +184,14 @@ class StableDiffusionControlNetInpaintPipeline(
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "control_image",
+ "mask",
+ "masked_image_latents",
+ ]
def __init__(
self,
@@ -224,7 +238,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -1469,6 +1483,7 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ control_image = callback_outputs.pop("control_image", control_image)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1477,6 +1492,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index a1b6de84da46..5907b41f4e73 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -35,7 +35,7 @@
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -54,13 +54,22 @@
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
-from .multicontrolnet import MultiControlNetModel
if is_invisible_watermark_available():
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -228,6 +237,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
"add_neg_time_ids",
"mask",
"masked_image_latents",
+ "control_image",
]
def __init__(
@@ -265,7 +275,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -407,7 +417,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -466,8 +478,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -730,7 +744,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -738,7 +752,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
@@ -1623,7 +1637,7 @@ def denoising_value_valid(dnv):
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
@@ -1631,7 +1645,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
@@ -1822,6 +1836,7 @@ def denoising_value_valid(dnv):
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ control_image = callback_outputs.pop("control_image", control_image)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1830,6 +1845,9 @@ def denoising_value_valid(dnv):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
self.upcast_vae()
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 0f3a15172843..77d496cf831d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -38,7 +38,7 @@
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -61,8 +61,16 @@
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
-from .multicontrolnet import MultiControlNetModel
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -242,6 +250,7 @@ class StableDiffusionXLControlNetPipeline(
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
+ "image",
]
def __init__(
@@ -276,7 +285,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -416,7 +425,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -475,8 +486,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1540,6 +1553,7 @@ def __call__(
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+ image = callback_outputs.pop("image", image)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1548,6 +1562,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 21cd87f7570e..04f069e12eb9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -38,7 +38,7 @@
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -61,8 +61,16 @@
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
-from .multicontrolnet import MultiControlNetModel
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -234,6 +242,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
"add_time_ids",
"negative_pooled_prompt_embeds",
"add_neg_time_ids",
+ "control_image",
]
def __init__(
@@ -269,7 +278,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -410,7 +419,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -469,8 +480,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1602,6 +1615,7 @@ def __call__(
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+ control_image = callback_outputs.pop("control_image", control_image)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1610,6 +1624,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
new file mode 100644
index 000000000000..8aae9ee7a281
--- /dev/null
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
@@ -0,0 +1,1791 @@
+# Copyright 2024 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
+from ...models.attention_processor import (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+)
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
+ from diffusers.utils import load_image
+ import torch
+ import numpy as np
+ from PIL import Image
+
+ prompt = "A cat"
+ # download an image
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo.png"
+ ).resize((1024, 1024))
+ mask = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+ ).resize((1024, 1024))
+ # initialize the models and pipeline
+ controlnet = ControlNetUnionModel.from_pretrained(
+ "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
+ )
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+ pipe = StableDiffusionXLControlNetUnionInpaintPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ controlnet=controlnet,
+ vae=vae,
+ torch_dtype=torch.float16,
+ variant="fp16",
+ )
+ pipe.enable_model_cpu_offload()
+ controlnet_img = image.copy()
+ controlnet_img_np = np.array(controlnet_img)
+ mask_np = np.array(mask)
+ controlnet_img_np[mask_np > 0] = 0
+ controlnet_img = Image.fromarray(controlnet_img_np)
+ # generate image
+ image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7]).images[0]
+ image.save("inpaint.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionXLControlNetUnionInpaintPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "image_encoder",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "mask",
+ "masked_image_latents",
+ "control_image",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: ControlNetUnionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ feature_extractor: Optional[CLIPImageProcessor] = None,
+ image_encoder: Optional[CLIPVisionModelWithProjection] = None,
+ ):
+ super().__init__()
+
+ if not isinstance(controlnet, ControlNetUnionModel):
+ raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ image_embeds = []
+ if do_classifier_free_guidance:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if do_classifier_free_guidance:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ output_type,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+ if num_inference_steps is None:
+ raise ValueError("`num_inference_steps` cannot be None.")
+ elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+ raise ValueError(
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+ f" {type(num_inference_steps)}."
+ )
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # Check `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.controlnet, ControlNetUnionModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+
+ else:
+ assert False
+
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline.prepare_control_image
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ crops_coords,
+ resize_mode,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ ).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ add_noise=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+
+ if image.shape[1] == 4:
+ image_latents = image
+ else:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ dtype = image.dtype
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline.prepare_mask_latents
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+ masked_image_latents = None
+ if masked_image is not None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ else:
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+
+ num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ t_start = len(self.scheduler.timesteps) - num_inference_steps
+ timesteps = self.scheduler.timesteps[t_start:]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start)
+ return timesteps, num_inference_steps
+
+ def _get_add_time_ids(
+ self,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.9999,
+ num_inference_steps: int = 50,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_mode: Optional[Union[int, List[int]]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ration of the image and contains all masked area, and then expand that area based
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
+ the image is large and contain information irrelevant for inpainting, such as background.
+ strength (`float`, *optional*, defaults to 0.9999):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
+ integer, the value of `strength` will be ignored.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple. `tuple. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+
+ # # 0.0 Default height and width to unet
+ # height = height or self.unet.config.sample_size * self.vae_scale_factor
+ # width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 0.1 align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+
+ if not isinstance(control_image, list):
+ control_image = [control_image]
+ else:
+ control_image = control_image.copy()
+
+ if not isinstance(control_mode, list):
+ control_mode = [control_mode]
+
+ if len(control_image) != len(control_mode):
+ raise ValueError("Expected len(control_image) == len(control_type)")
+
+ num_control_type = controlnet.config.num_control_type
+
+ # 1. Check inputs
+ control_type = [0 for _ in range(num_control_type)]
+ for _image, control_idx in zip(control_image, control_mode):
+ control_type[control_idx] = 1
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ _image,
+ mask_image,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ output_type,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ callback_on_step_end_tensor_inputs,
+ padding_mask_crop,
+ )
+
+ control_type = torch.Tensor(control_type)
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 3.1 Encode ip_adapter_image
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 4. set timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=denoising_start if denoising_value_valid(denoising_start) else None,
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+ self._num_timesteps = len(timesteps)
+
+ # 5. Preprocess mask and image - resizes image and mask w.r.t height and width
+ # 5.1 Prepare init image
+ if padding_mask_crop is not None:
+ height, width = self.image_processor.get_default_height_width(image, height, width)
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 5.2 Prepare control images
+ for idx, _ in enumerate(control_image):
+ control_image[idx] = self.prepare_control_image(
+ image=control_image[idx],
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ height, width = control_image[idx].shape[-2:]
+
+ # 5.3 Prepare mask
+ mask = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ masked_image = init_image * (mask < 0.5)
+ _, _, height, width = init_image.shape
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ add_noise = True if denoising_start is None else False
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ add_noise=add_noise,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, _ = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+ # 8.1 Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8.2 Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ controlnet_keep.append(
+ 1.0
+ - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+ for _image in control_image:
+ if isinstance(_image, torch.Tensor):
+ original_size = original_size or _image.shape[-2:]
+
+ # 10. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ if (
+ denoising_end is not None
+ and denoising_start is not None
+ and denoising_value_valid(denoising_end)
+ and denoising_value_valid(denoising_start)
+ and denoising_start >= denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {denoising_end} when using type float."
+ )
+ elif denoising_end is not None and denoising_value_valid(denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ control_type = (
+ control_type.reshape(1, -1)
+ .to(device, dtype=prompt_embeds.dtype)
+ .repeat(batch_size * num_images_per_prompt * 2, 1)
+ )
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ added_cond_kwargs = {
+ "text_embeds": add_text_embeds,
+ "time_ids": add_time_ids,
+ }
+
+ # controlnet(s) inference
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ controlnet_added_cond_kwargs = {
+ "text_embeds": add_text_embeds.chunk(2)[1],
+ "time_ids": add_time_ids.chunk(2)[1],
+ }
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_added_cond_kwargs = added_cond_kwargs
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ # # Resize control_image to match the size of the input to the controlnet
+ # if control_image.shape[-2:] != control_model_input.shape[-2:]:
+ # control_image = F.interpolate(control_image, size=control_model_input.shape[-2:], mode="bilinear", align_corners=False)
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ control_type=control_type,
+ control_type_idx=control_mode,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ init_latents_proper = image_latents
+ if self.do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ control_image = callback_outputs.pop("control_image", control_image)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ return StableDiffusionXLPipelineOutput(images=latents)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if padding_mask_crop is not None:
+ image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
new file mode 100644
index 000000000000..ca931c221eec
--- /dev/null
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
@@ -0,0 +1,1616 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers.utils.import_utils import is_invisible_watermark_available
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from ...models import (
+ AutoencoderKL,
+ ControlNetUnionModel,
+ ImageProjection,
+ MultiControlNetUnionModel,
+ UNet2DConditionModel,
+)
+from ...models.attention_processor import (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+)
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install controlnet_aux
+ >>> from controlnet_aux import LineartAnimeDetector
+ >>> from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel, AutoencoderKL
+ >>> from diffusers.utils import load_image
+ >>> import torch
+
+ >>> prompt = "A cat"
+ >>> # download an image
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
+ ... ).resize((1024, 1024))
+ >>> # initialize the models and pipeline
+ >>> controlnet = ControlNetUnionModel.from_pretrained(
+ ... "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
+ ... )
+ >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
+ ... controlnet=controlnet,
+ ... vae=vae,
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+ >>> # prepare image
+ >>> processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
+ >>> controlnet_img = processor(image, output_type="pil")
+ >>> # generate image
+ >>> image = pipe(prompt, control_image=[controlnet_img], control_mode=[3], height=1024, width=1024).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionXLControlNetUnionPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ IPAdapterMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
+ Second frozen text-encoder
+ ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ tokenizer_2 ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ controlnet ([`ControlNetUnionModel`]`):
+ Provides additional conditioning to the `unet` during the denoising process.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings should always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
+ watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
+ watermarker is used.
+ """
+
+ # leave controlnet out on purpose because it iterates with unet
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "feature_extractor",
+ "image_encoder",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[
+ ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
+ ],
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ feature_extractor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ ):
+ super().__init__()
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetUnionModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ image_embeds = []
+ if do_classifier_free_guidance:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if do_classifier_free_guidance:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image: PipelineImageInput,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ control_mode=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetUnionModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
+ # Check `image`
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ if isinstance(controlnet, ControlNetUnionModel):
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+ elif not all(isinstance(i, list) for i in image):
+ raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
+ elif len(image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for images_ in image:
+ for image_ in images_:
+ self.check_image(image_, prompt, prompt_embeds)
+
+ # Check `controlnet_conditioning_scale`
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control_guidance_start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control_guidance_start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control_guidance_end: {end} can't be larger than 1.0.")
+
+ # Check `control_mode`
+ if isinstance(controlnet, ControlNetUnionModel):
+ if max(control_mode) >= controlnet.config.num_control_type:
+ raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
+ if max(_control_mode) >= _controlnet.config.num_control_type:
+ raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
+
+ # Equal number of `image` and `control_mode` elements
+ if isinstance(controlnet, ControlNetUnionModel):
+ if len(image) != len(control_mode):
+ raise ValueError("Expected len(control_image) == len(control_mode)")
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ if not all(isinstance(i, list) for i in control_mode):
+ raise ValueError(
+ "For multiple controlnets: elements of control_mode must be lists representing conditioning mode."
+ )
+
+ elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
+ raise ValueError("Expected len(control_image) == len(control_mode)")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders.
+ control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, pooled text embeddings are generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
+ argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ control_mode (`int` or `List[int]` or `List[List[int]], *optional*):
+ The control condition types for the ControlNet. See the ControlNet's model card forinformation on the
+ available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
+ where each ControlNet should have its corresponding control mode list. Should reflect the order of
+ conditions in control_image.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be as same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned containing the output images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ if not isinstance(control_image, list):
+ control_image = [control_image]
+ else:
+ control_image = control_image.copy()
+
+ if not isinstance(control_mode, list):
+ control_mode = [control_mode]
+
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ control_image = [[item] for item in control_image]
+ control_mode = [[item] for item in control_mode]
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ if isinstance(controlnet_conditioning_scale, float):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ control_image,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ negative_pooled_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ control_mode,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_type = [
+ torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
+ for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
+ ]
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetUnionModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3.1 Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt,
+ prompt_2,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 3.2 Encode ip_adapter_image
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 4. Prepare image
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_images = []
+
+ for image_ in control_image:
+ image_ = self.prepare_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(image_)
+
+ control_image = control_images
+ height, width = control_image[0].shape[-2:]
+
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ images = []
+
+ for image_ in control_image_:
+ image_ = self.prepare_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+ control_images.append(images)
+
+ control_image = control_images
+ height, width = control_image[0][0].shape[-2:]
+
+ # 5. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.1 Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps)
+
+ # 7.2 Prepare added time ids & embeddings
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+ for _image in control_image:
+ if isinstance(_image, torch.Tensor):
+ original_size = original_size or _image.shape[-2:]
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # 8.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ is_unet_compiled = is_compiled_module(self.unet)
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_type = (
+ control_type.reshape(1, -1)
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
+ .repeat(batch_size * num_images_per_prompt * 2, 1)
+ )
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ control_type = [
+ _control_type.reshape(1, -1)
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
+ .repeat(batch_size * num_images_per_prompt * 2, 1)
+ for _control_type in control_type
+ ]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # Relevant thread:
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ torch._inductor.cudagraph_mark_step_begin()
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ added_cond_kwargs = {
+ "text_embeds": add_text_embeds,
+ "time_ids": add_time_ids,
+ }
+
+ # controlnet(s) inference
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ controlnet_added_cond_kwargs = {
+ "text_embeds": add_text_embeds.chunk(2)[1],
+ "time_ids": add_time_ids.chunk(2)[1],
+ }
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_added_cond_kwargs = added_cond_kwargs
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ control_type=control_type,
+ control_type_idx=control_mode,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
new file mode 100644
index 000000000000..87398395d99e
--- /dev/null
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
@@ -0,0 +1,1623 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers.utils.import_utils import is_invisible_watermark_available
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
+from ...models.attention_processor import (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+)
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ # !pip install controlnet_aux
+ from diffusers import (
+ StableDiffusionXLControlNetUnionImg2ImgPipeline,
+ ControlNetUnionModel,
+ AutoencoderKL,
+ )
+ from diffusers.utils import load_image
+ import torch
+ from PIL import Image
+ import numpy as np
+
+ prompt = "A cat"
+ # download an image
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
+ )
+ # initialize the models and pipeline
+ controlnet = ControlNetUnionModel.from_pretrained(
+ "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
+ )
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+ pipe = StableDiffusionXLControlNetUnionImg2ImgPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ controlnet=controlnet,
+ vae=vae,
+ torch_dtype=torch.float16,
+ variant="fp16",
+ ).to("cuda")
+ # `enable_model_cpu_offload` is not recommended due to multiple generations
+ height = image.height
+ width = image.width
+ ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
+ # 3 * 3 upscale correspond to 16 * 3 multiply, 2 * 2 correspond to 16 * 2 multiply and so on.
+ scale_image_factor = 3
+ base_factor = 16
+ factor = scale_image_factor * base_factor
+ W, H = int(width * ratio) // factor * factor, int(height * ratio) // factor * factor
+ image = image.resize((W, H))
+ target_width = W // scale_image_factor
+ target_height = H // scale_image_factor
+ images = []
+ crops_coords_list = [
+ (0, 0),
+ (0, width // 2),
+ (height // 2, 0),
+ (width // 2, height // 2),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ ]
+ for i in range(scale_image_factor):
+ for j in range(scale_image_factor):
+ left = j * target_width
+ top = i * target_height
+ right = left + target_width
+ bottom = top + target_height
+ cropped_image = image.crop((left, top, right, bottom))
+ cropped_image = cropped_image.resize((W, H))
+ images.append(cropped_image)
+ # set ControlNetUnion input
+ result_images = []
+ for sub_img, crops_coords in zip(images, crops_coords_list):
+ new_width, new_height = W, H
+ out = pipe(
+ prompt=[prompt] * 1,
+ image=sub_img,
+ control_image=[sub_img],
+ control_mode=[6],
+ width=new_width,
+ height=new_height,
+ num_inference_steps=30,
+ crops_coords_top_left=(W, H),
+ target_size=(W, H),
+ original_size=(W * 2, H * 2),
+ )
+ result_images.append(out.images[0])
+ new_im = Image.new("RGB", (new_width * scale_image_factor, new_height * scale_image_factor))
+ new_im.paste(result_images[0], (0, 0))
+ new_im.paste(result_images[1], (new_width, 0))
+ new_im.paste(result_images[2], (new_width * 2, 0))
+ new_im.paste(result_images[3], (0, new_height))
+ new_im.paste(result_images[4], (new_width, new_height))
+ new_im.paste(result_images[5], (new_width * 2, new_height))
+ new_im.paste(result_images[6], (0, new_height * 2))
+ new_im.paste(result_images[7], (new_width, new_height * 2))
+ new_im.paste(result_images[8], (new_width * 2, new_height * 2))
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class StableDiffusionXLControlNetUnionImg2ImgPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ FromSingleFileMixin,
+ IPAdapterMixin,
+):
+ r"""
+ Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ controlnet ([`ControlNetUnionModel`]):
+ Provides additional conditioning to the unet during the denoising process.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+ Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
+ config of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "feature_extractor",
+ "image_encoder",
+ ]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "add_text_embeds", "add_time_ids", "control_image"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: ControlNetUnionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ feature_extractor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ ):
+ super().__init__()
+
+ if not isinstance(controlnet, ControlNetUnionModel):
+ raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ image_embeds = []
+ if do_classifier_free_guidance:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if do_classifier_free_guidance:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+ if num_inference_steps is None:
+ raise ValueError("`num_inference_steps` cannot be None.")
+ elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+ raise ValueError(
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+ f" {type(num_inference_steps)}."
+ )
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # Check `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.controlnet, ControlNetUnionModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+ else:
+ assert False
+
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
+ def prepare_latents(
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
+ ):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ latents_mean = latents_std = None
+ if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
+
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.text_encoder_2.to("cpu")
+ torch.cuda.empty_cache()
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+ image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+ elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+ )
+
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=device, dtype=dtype)
+ latents_std = latents_std.to(device=device, dtype=dtype)
+ init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
+ else:
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ if add_noise:
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+
+ latents = init_latents
+
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+ )
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_mode: Optional[Union[int, List[int]]] = None,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The initial image will be used as the starting point for the image generation process. Can also accept
+ image latents as `image`, if passing latents directly, it will not be encoded again.
+ control_image (`PipelineImageInput`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+ the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
+ be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
+ and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
+ init, images must be passed as a list such that each element of the list can be correctly batched for
+ input to a single controlnet.
+ height (`int`, *optional*, defaults to the size of control_image):
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to the size of control_image):
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ strength (`float`, *optional*, defaults to 0.8):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
+ you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the controlnet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the controlnet stops applying.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be as same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
+ containing the output images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+
+ if not isinstance(control_image, list):
+ control_image = [control_image]
+ else:
+ control_image = control_image.copy()
+
+ if not isinstance(control_mode, list):
+ control_mode = [control_mode]
+
+ if len(control_image) != len(control_mode):
+ raise ValueError("Expected len(control_image) == len(control_type)")
+
+ num_control_type = controlnet.config.num_control_type
+
+ # 1. Check inputs
+ control_type = [0 for _ in range(num_control_type)]
+ for _image, control_idx in zip(control_image, control_mode):
+ control_type[control_idx] = 1
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ _image,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ control_type = torch.Tensor(control_type)
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ global_pool_conditions = controlnet.config.global_pool_conditions
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3.1. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt,
+ prompt_2,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 3.2 Encode ip_adapter_image
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 4. Prepare image and controlnet_conditioning_image
+ image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+
+ for idx, _ in enumerate(control_image):
+ control_image[idx] = self.prepare_control_image(
+ image=control_image[idx],
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ height, width = control_image[idx].shape[-2:]
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare latent variables
+ if latents is None:
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ True,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.1 Create tensor stating which controlnets to keep
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ controlnet_keep.append(
+ 1.0
+ - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
+ )
+
+ # 7.2 Prepare added time ids & embeddings
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+ for _image in control_image:
+ if isinstance(_image, torch.Tensor):
+ original_size = original_size or _image.shape[-2:]
+
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+ add_text_embeds = pooled_prompt_embeds
+
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+ control_type = (
+ control_type.reshape(1, -1)
+ .to(device, dtype=prompt_embeds.dtype)
+ .repeat(batch_size * num_images_per_prompt * 2, 1)
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ added_cond_kwargs = {
+ "text_embeds": add_text_embeds,
+ "time_ids": add_time_ids,
+ }
+
+ # controlnet(s) inference
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ controlnet_added_cond_kwargs = {
+ "text_embeds": add_text_embeds.chunk(2)[1],
+ "time_ids": add_time_ids.chunk(2)[1],
+ }
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_added_cond_kwargs = added_cond_kwargs
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ control_type=control_type,
+ control_type_idx=control_mode,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ if guess_mode and self.do_classifier_free_guidance:
+ # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ control_image = callback_outputs.pop("control_image", control_image)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
index 8a2cc08dbb2b..3d4b19ea552c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
@@ -75,7 +75,10 @@
... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
... )
>>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ ... controlnet=controlnet,
+ ... revision="flax",
+ ... dtype=jnp.float32,
... )
>>> params["controlnet"] = controlnet_params
@@ -132,8 +135,8 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -175,7 +178,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def prepare_text_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
index 45e17f3de1e2..5ee712b5f116 100644
--- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
@@ -232,8 +232,8 @@ def __init__(
Tuple[HunyuanDiT2DControlNetModel],
HunyuanDiT2DMultiControlNetModel,
],
- text_encoder_2=T5EncoderModel,
- tokenizer_2=MT5Tokenizer,
+ text_encoder_2: Optional[T5EncoderModel] = None,
+ tokenizer_2: Optional[MT5Tokenizer] = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -269,9 +269,7 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = (
@@ -925,7 +923,11 @@ def __call__(
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
- self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
+ self.transformer.inner_dim // self.transformer.num_heads,
+ grid_crops_coords,
+ (grid_height, grid_width),
+ device=device,
+ output_type="pt",
)
style = torch.tensor([0], device=device)
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index 9f674d2d7897..7f7acd882b59 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -19,14 +19,16 @@
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
+ SiglipImageProcessor,
+ SiglipVisionModel,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
@@ -66,9 +68,13 @@
... "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> pipe.to("cuda")
- >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
- >>> prompt = "A girl holding a sign that says InstantX"
- >>> image = pipe(prompt, control_image=control_image, controlnet_conditioning_scale=0.7).images[0]
+ >>> control_image = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+ ... )
+ >>> prompt = "A bird in space"
+ >>> image = pipe(
+ ... prompt, control_image=control_image, height=1024, width=768, controlnet_conditioning_scale=0.7
+ ... ).images[0]
>>> image.save("sd3.png")
```
"""
@@ -134,7 +140,9 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetPipeline(
+ DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
@@ -170,10 +178,14 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
additional conditioning.
+ image_encoder (`SiglipVisionModel`, *optional*):
+ Pre-trained Vision Model for IP Adapter.
+ feature_extractor (`SiglipImageProcessor`, *optional*):
+ Image processor for IP Adapter.
"""
- model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
- _optional_components = []
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(
@@ -190,10 +202,25 @@ def __init__(
controlnet: Union[
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
],
+ image_encoder: Optional[SiglipVisionModel] = None,
+ feature_extractor: Optional[SiglipImageProcessor] = None,
):
super().__init__()
if isinstance(controlnet, (list, tuple)):
controlnet = SD3MultiControlNetModel(controlnet)
+ if isinstance(controlnet, SD3MultiControlNetModel):
+ for controlnet_model in controlnet.nets:
+ # for SD3.5 8b controlnet, it shares the pos_embed with the transformer
+ if (
+ hasattr(controlnet_model.config, "use_pos_embed")
+ and controlnet_model.config.use_pos_embed is False
+ ):
+ pos_embed = controlnet_model._get_pos_embed_from_transformer(transformer)
+ controlnet_model.pos_embed = pos_embed.to(controlnet_model.dtype).to(controlnet_model.device)
+ elif isinstance(controlnet, SD3ControlNetModel):
+ if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False:
+ pos_embed = controlnet._get_pos_embed_from_transformer(transformer)
+ controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device)
self.register_modules(
vae=vae,
@@ -206,10 +233,10 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
controlnet=controlnet,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -377,9 +404,9 @@ def encode_prompt(
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
- negative_prompt_2 (`str` or `List[str]`, *optional*):
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
- `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -710,6 +737,84 @@ def num_timesteps(self):
def interrupt(self):
return self._interrupt
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+ Args:
+ image (`PipelineImageInput`):
+ Input image to be encoded.
+ device: (`torch.device`):
+ Torch device.
+
+ Returns:
+ `torch.Tensor`: The encoded image feature representation.
+ """
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=self.dtype)
+
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ ) -> torch.Tensor:
+ """Prepares image embeddings for use in the IP-Adapter.
+
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+ Args:
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ The input image to extract features from for IP-Adapter.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Precomputed image embeddings.
+ device: (`torch.device`, *optional*):
+ Torch device.
+ num_images_per_prompt (`int`, defaults to 1):
+ Number of images that should be generated per prompt.
+ do_classifier_free_guidance (`bool`, defaults to True):
+ Whether to use classifier free guidance or not.
+ """
+ device = device or self._execution_device
+
+ if ip_adapter_image_embeds is not None:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+ else:
+ single_image_embeds = ip_adapter_image_embeds
+ elif ip_adapter_image is not None:
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+ else:
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+ return image_embeds.to(device=device)
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+ logger.warning(
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+ )
+
+ super().enable_sequential_cpu_offload(*args, **kwargs)
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -720,7 +825,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
@@ -737,6 +842,8 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -765,10 +872,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -826,6 +933,12 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -858,6 +971,12 @@ def __call__(
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
+ controlnet_config = (
+ self.controlnet.config
+ if isinstance(self.controlnet, SD3ControlNetModel)
+ else self.controlnet.nets[0].config
+ )
+
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -932,6 +1051,11 @@ def __call__(
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 3. Prepare control image
+ if controlnet_config.force_zeros_for_pooled_projection:
+ # instantx sd3 controlnet does not apply shift factor
+ vae_shift_factor = 0
+ else:
+ vae_shift_factor = self.vae.config.shift_factor
if isinstance(self.controlnet, SD3ControlNetModel):
control_image = self.prepare_image(
image=control_image,
@@ -947,8 +1071,7 @@ def __call__(
height, width = control_image.shape[-2:]
control_image = self.vae.encode(control_image).latent_dist.sample()
- control_image = control_image * self.vae.config.scaling_factor
-
+ control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor
elif isinstance(self.controlnet, SD3MultiControlNetModel):
control_images = []
@@ -966,7 +1089,7 @@ def __call__(
)
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
- control_image_ = control_image_ * self.vae.config.scaling_factor
+ control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor
control_images.append(control_image_)
@@ -974,13 +1097,8 @@ def __call__(
else:
assert False
- if controlnet_pooled_projections is None:
- controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
- else:
- controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
-
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
@@ -1006,7 +1124,34 @@ def __call__(
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)
- # 7. Denoising loop
+ if controlnet_config.force_zeros_for_pooled_projection:
+ # instantx sd3 controlnet used zero pooled projection
+ controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
+ else:
+ controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
+
+ if controlnet_config.joint_attention_dim is not None:
+ controlnet_encoder_hidden_states = prompt_embeds
+ else:
+ # SD35 official 8b controlnet does not use encoder_hidden_states
+ controlnet_encoder_hidden_states = None
+
+ # 7. Prepare image embeddings
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+ else:
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+ # 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
@@ -1029,7 +1174,7 @@ def __call__(
control_block_samples = self.controlnet(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states=controlnet_encoder_hidden_states,
pooled_projections=controlnet_pooled_projections,
joint_attention_kwargs=self.joint_attention_kwargs,
controlnet_cond=control_image,
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index f362c8f3d0c1..cb35f67fa112 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -19,14 +19,16 @@
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
+ SiglipImageProcessor,
+ SiglipModel,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
@@ -159,7 +161,9 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetInpaintingPipeline(
+ DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
@@ -192,13 +196,17 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
controlnet ([`SD3ControlNetModel`] or `List[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]):
- Provides additional conditioning to the `unet` during the denoising process. If you set multiple
+ Provides additional conditioning to the `transformer` during the denoising process. If you set multiple
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
additional conditioning.
+ image_encoder (`PreTrainedModel`, *optional*):
+ Pre-trained Vision Model for IP Adapter.
+ feature_extractor (`BaseImageProcessor`, *optional*):
+ Image processor for IP Adapter.
"""
- model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
- _optional_components = []
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(
@@ -215,6 +223,8 @@ def __init__(
controlnet: Union[
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
],
+ image_encoder: SiglipModel = None,
+ feature_extractor: Optional[SiglipImageProcessor] = None,
):
super().__init__()
@@ -229,10 +239,10 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
controlnet=controlnet,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_resize=True, do_convert_rgb=True, do_normalize=True
)
@@ -412,9 +422,9 @@ def encode_prompt(
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
- negative_prompt_2 (`str` or `List[str]`, *optional*):
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
- `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -777,6 +787,84 @@ def num_timesteps(self):
def interrupt(self):
return self._interrupt
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+ Args:
+ image (`PipelineImageInput`):
+ Input image to be encoded.
+ device: (`torch.device`):
+ Torch device.
+
+ Returns:
+ `torch.Tensor`: The encoded image feature representation.
+ """
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=self.dtype)
+
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ ) -> torch.Tensor:
+ """Prepares image embeddings for use in the IP-Adapter.
+
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+ Args:
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ The input image to extract features from for IP-Adapter.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Precomputed image embeddings.
+ device: (`torch.device`, *optional*):
+ Torch device.
+ num_images_per_prompt (`int`, defaults to 1):
+ Number of images that should be generated per prompt.
+ do_classifier_free_guidance (`bool`, defaults to True):
+ Whether to use classifier free guidance or not.
+ """
+ device = device or self._execution_device
+
+ if ip_adapter_image_embeds is not None:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+ else:
+ single_image_embeds = ip_adapter_image_embeds
+ elif ip_adapter_image is not None:
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+ else:
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+ return image_embeds.to(device=device)
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+ logger.warning(
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+ )
+
+ super().enable_sequential_cpu_offload(*args, **kwargs)
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -787,7 +875,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
@@ -805,6 +893,8 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -833,10 +923,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -898,6 +988,12 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1033,7 +1129,7 @@ def __call__(
controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
@@ -1059,7 +1155,22 @@ def __call__(
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps)
- # 7. Denoising loop
+ # 7. Prepare image embeddings
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+ else:
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+ # 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
index ca10e65de8a4..901ca25c576c 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
@@ -30,6 +30,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -41,6 +42,13 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -178,7 +186,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -884,6 +892,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
index 326cfdab7be7..acf1f5489ec1 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
@@ -54,6 +54,16 @@
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -196,7 +206,7 @@ def __init__(
scheduler=scheduler,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -336,7 +346,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -395,8 +407,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1074,6 +1088,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# manually for max memory savings
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
self.upcast_vae()
diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
index bcd36c412b54..34b2a3945572 100644
--- a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
+++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
@@ -17,11 +17,20 @@
import torch
-from ...utils import logging
+from ...models import UNet1DModel
+from ...schedulers import SchedulerMixin
+from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -42,7 +51,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
model_cpu_offload_seq = "unet"
- def __init__(self, unet, scheduler):
+ def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@@ -146,6 +155,9 @@ def __call__(
# 2. compute previous audio sample: x_t -> t_t-1
audio = self.scheduler.step(model_output, t, audio).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
audio = audio.clamp(-1, 1).float().cpu().numpy()
audio = audio[:, :, :original_sample_size]
diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index a3b967ed369b..1fd8ce4e6570 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -16,11 +16,21 @@
import torch
+from ...models import UNet2DModel
from ...schedulers import DDIMScheduler
+from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
class DDIMPipeline(DiffusionPipeline):
r"""
Pipeline for image generation.
@@ -38,7 +48,7 @@ class DDIMPipeline(DiffusionPipeline):
model_cpu_offload_seq = "unet"
- def __init__(self, unet, scheduler):
+ def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler):
super().__init__()
# make sure scheduler can always be converted to DDIM
@@ -143,6 +153,9 @@ def __call__(
model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index bb03a8d66758..1c5ac4baeae0 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -17,10 +17,21 @@
import torch
+from ...models import UNet2DModel
+from ...schedulers import DDPMScheduler
+from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
class DDPMPipeline(DiffusionPipeline):
r"""
Pipeline for image generation.
@@ -38,7 +49,7 @@ class DDPMPipeline(DiffusionPipeline):
model_cpu_offload_seq = "unet"
- def __init__(self, unet, scheduler):
+ def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@@ -116,6 +127,9 @@ def __call__(
# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
index f545b24bec5c..150978de6e5e 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
@@ -14,6 +14,7 @@
BACKENDS_MAPPING,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -24,8 +25,16 @@
from .watermark import IFWatermarker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -735,6 +744,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
index 07017912575d..a92d7be6a11c 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
@@ -17,6 +17,7 @@
PIL_INTERPOLATION,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -27,8 +28,16 @@
from .watermark import IFWatermarker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -856,6 +865,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
index 6685ba6d774a..b23ea39bb292 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
@@ -35,6 +35,16 @@
import ftfy
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -174,7 +184,7 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- if unet.config.in_channels != 6:
+ if unet is not None and unet.config.in_channels != 6:
logger.warning(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
)
@@ -974,6 +984,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
index 7fca0bc0443c..030821b789aa 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
@@ -17,6 +17,7 @@
PIL_INTERPOLATION,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -27,8 +28,16 @@
from .watermark import IFWatermarker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -975,6 +984,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
index 4f04a1de2a6e..bdad9c29b18f 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
@@ -35,6 +35,16 @@
import ftfy
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -176,7 +186,7 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- if unet.config.in_channels != 6:
+ if unet is not None and unet.config.in_channels != 6:
logger.warning(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
)
@@ -1085,6 +1095,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
index 891963f2a904..012c4ca6d448 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
@@ -34,6 +34,16 @@
import ftfy
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -132,7 +142,7 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- if unet.config.in_channels != 6:
+ if unet is not None and unet.config.in_channels != 6:
logger.warning(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
)
@@ -831,6 +841,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
index a1930da4180e..48c0aa4f6d76 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
@@ -210,7 +210,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -224,7 +224,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -253,10 +253,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -284,7 +288,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
index e40b6efd71ab..fa70689d790d 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -238,7 +238,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -252,7 +252,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -281,10 +281,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -312,7 +316,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py b/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py
index 101d315dfe59..843528a532f1 100644
--- a/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py
+++ b/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py
@@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline):
scheduler: RePaintScheduler
model_cpu_offload_seq = "unet"
- def __init__(self, unet, scheduler):
+ def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
index 777be883cb9d..1752540e8f79 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
@@ -184,7 +184,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -213,10 +213,14 @@ def __init__(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -243,7 +247,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
index 0aa5e68bfcb4..e9553a8d99b0 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
@@ -93,7 +93,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -107,7 +107,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
index ce7ad3b0dfe9..f9c9c37c4867 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
@@ -140,7 +140,7 @@ def __init__(
)
deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False)
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -154,7 +154,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -183,10 +183,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -213,7 +217,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
index 9e91986896bd..06db871daf62 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
@@ -121,7 +121,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
index be21900ab55a..d486a32f6a4c 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
@@ -143,7 +143,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
index 2978972200c7..509f25620950 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
@@ -365,7 +365,7 @@ def __init__(
caption_generator=caption_generator,
inverse_scheduler=inverse_scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 3937e87f63c9..bc276811ff4a 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -34,7 +34,7 @@
from ....models.transformers.dual_transformer_2d import DualTransformer2DModel
from ....models.transformers.transformer_2d import Transformer2DModel
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
-from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ....utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ....utils.torch_utils import apply_freeu
@@ -963,10 +963,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -1163,10 +1159,11 @@ def forward(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
+ is_npu = sample.device.type == "npu"
if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
@@ -1595,22 +1592,8 @@ def forward(
output_states = ()
for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -1732,24 +1715,8 @@ def forward(
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -1874,22 +1841,8 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
@@ -2033,24 +1986,8 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -2223,12 +2160,19 @@ def __init__(
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
+ self.gradient_checkpointing = False
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if attn is not None:
- hidden_states = attn(hidden_states, temb=temb)
- hidden_states = resnet(hidden_states, temb)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
+ else:
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -2352,18 +2296,7 @@ def forward(
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -2372,12 +2305,7 @@ def custom_forward(*inputs):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
else:
hidden_states = attn(
hidden_states,
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
index c8dc18e2e8ac..4fb437958abd 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
@@ -76,7 +76,7 @@ def __init__(
vae=vae,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
@torch.no_grad()
def image_variation(
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
index 2212651fbb5b..0065279bc0b1 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -94,7 +94,7 @@ def __init__(
vae=vae,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
if self.text_unet is not None and (
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
index 62d3e83a4790..7dfc7e961825 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -77,7 +77,7 @@ def __init__(
vae=vae,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
index de4c2ac9b7f4..1d6771793f39 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -82,7 +82,7 @@ def __init__(
vae=vae,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
if self.text_unet is not None:
diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py
index 14321b5f33cf..8aee0fadaf69 100644
--- a/src/diffusers/pipelines/dit/pipeline_dit.py
+++ b/src/diffusers/pipelines/dit/pipeline_dit.py
@@ -24,10 +24,19 @@
from ...models import AutoencoderKL, DiTTransformer2DModel
from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
class DiTPipeline(DiffusionPipeline):
r"""
Pipeline for image generation based on a Transformer backbone instead of a UNet.
@@ -178,10 +187,11 @@ def __call__(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
+ is_npu = latent_model_input.device.type == "npu"
if isinstance(timesteps, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(latent_model_input.device)
@@ -211,6 +221,9 @@ def __call__(
# compute previous image: x_t -> x_t-1
latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if guidance_scale > 1:
latents, _ = latent_model_input.chunk(2, dim=0)
else:
diff --git a/src/diffusers/pipelines/easyanimate/__init__.py b/src/diffusers/pipelines/easyanimate/__init__.py
new file mode 100644
index 000000000000..49923423f951
--- /dev/null
+++ b/src/diffusers/pipelines/easyanimate/__init__.py
@@ -0,0 +1,52 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_easyanimate"] = ["EasyAnimatePipeline"]
+ _import_structure["pipeline_easyanimate_control"] = ["EasyAnimateControlPipeline"]
+ _import_structure["pipeline_easyanimate_inpaint"] = ["EasyAnimateInpaintPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_easyanimate import EasyAnimatePipeline
+ from .pipeline_easyanimate_control import EasyAnimateControlPipeline
+ from .pipeline_easyanimate_inpaint import EasyAnimateInpaintPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py
new file mode 100755
index 000000000000..25975b04f395
--- /dev/null
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py
@@ -0,0 +1,770 @@
+# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import (
+ BertModel,
+ BertTokenizer,
+ Qwen2Tokenizer,
+ Qwen2VLForConditionalGeneration,
+)
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from .pipeline_output import EasyAnimatePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import EasyAnimatePipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh"
+ >>> pipe = EasyAnimatePipeline.from_pretrained(
+ ... "alibaba-pai/EasyAnimateV5.1-7b-zh-diffusers", torch_dtype=torch.float16
+ ... ).to("cuda")
+ >>> prompt = (
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+ ... "atmosphere of this unique musical performance."
+ ... )
+ >>> sample_size = (512, 512)
+ >>> video = pipe(
+ ... prompt=prompt,
+ ... guidance_scale=6,
+ ... negative_prompt="bad detailed",
+ ... height=sample_size[0],
+ ... width=sample_size[1],
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=8)
+ ```
+"""
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class EasyAnimatePipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using EasyAnimate.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
+
+ Args:
+ vae ([`AutoencoderKLMagvit`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
+ text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
+ EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
+ tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
+ A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
+ transformer ([`EasyAnimateTransformer3DModel`]):
+ The EasyAnimate model designed by EasyAnimate Team.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKLMagvit,
+ text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
+ tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
+ transformer: EasyAnimateTransformer3DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.enable_text_attention_mask = (
+ self.transformer.config.enable_text_attention_mask
+ if getattr(self, "transformer", None) is not None
+ else True
+ )
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
+ )
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ dtype (`torch.dtype`):
+ torch dtype
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
+ max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
+ """
+ dtype = dtype or self.text_encoder.dtype
+ device = device or self.text_encoder.device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ if isinstance(prompt, str):
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": prompt}],
+ }
+ ]
+ else:
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": _prompt}],
+ }
+ for _prompt in prompt
+ ]
+ text = [
+ self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
+ ]
+
+ text_inputs = self.tokenizer(
+ text=text,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_attention_mask=True,
+ padding_side="right",
+ return_tensors="pt",
+ )
+ text_inputs = text_inputs.to(self.text_encoder.device)
+
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ if self.enable_text_attention_mask:
+ # Inference: Generation of the output
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
+ ).hidden_states[-2]
+ else:
+ raise ValueError("LLM needs attention_mask")
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ if negative_prompt is not None and isinstance(negative_prompt, str):
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": negative_prompt}],
+ }
+ ]
+ else:
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": _negative_prompt}],
+ }
+ for _negative_prompt in negative_prompt
+ ]
+ text = [
+ self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
+ ]
+
+ text_inputs = self.tokenizer(
+ text=text,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_attention_mask=True,
+ padding_side="right",
+ return_tensors="pt",
+ )
+ text_inputs = text_inputs.to(self.text_encoder.device)
+
+ text_input_ids = text_inputs.input_ids
+ negative_prompt_attention_mask = text_inputs.attention_mask
+ if self.enable_text_attention_mask:
+ # Inference: Generation of the output
+ negative_prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=negative_prompt_attention_mask,
+ output_hidden_states=True,
+ ).hidden_states[-2]
+ else:
+ raise ValueError("LLM needs attention_mask")
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
+
+ return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_temporal_compression_ratio + 1,
+ height // self.vae_spatial_compression_ratio,
+ width // self.vae_spatial_compression_ratio,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # scale the initial noise by the standard deviation required by the scheduler
+ if hasattr(self.scheduler, "init_noise_sigma"):
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_frames: Optional[int] = 49,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ timesteps: Optional[List[int]] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ guidance_rescale: float = 0.0,
+ ):
+ r"""
+ Generates images or video using the EasyAnimate pipeline based on the provided prompts.
+
+ Examples:
+ prompt (`str` or `List[str]`, *optional*):
+ Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
+ num_frames (`int`, *optional*):
+ Length of the generated video (in frames).
+ height (`int`, *optional*):
+ Height of the generated image in pixels.
+ width (`int`, *optional*):
+ Width of the generated image in pixels.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ Number of denoising steps during generation. More steps generally yield higher quality images but slow
+ down inference.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Encourages the model to align outputs with prompts. A higher value may decrease image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ Number of images to generate for each prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A generator to ensure reproducibility in image generation.
+ latents (`torch.Tensor`, *optional*):
+ Predefined latent tensors to condition generation.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Embeddings for negative prompts. Overrides string inputs if defined.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the primary prompt embeddings.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for negative prompt embeddings.
+ output_type (`str`, *optional*, defaults to "latent"):
+ Format of the generated output, either as a PIL image or as a NumPy array.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ If `True`, returns a structured output. Otherwise returns a simple tuple.
+ callback_on_step_end (`Callable`, *optional*):
+ Functions called at the end of each denoising step.
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
+ Tensor names to be included in callback function calls.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Adjusts noise levels based on guidance scale.
+ original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
+ Original dimensions of the output.
+ target_size (`Tuple[int, int]`, *optional*):
+ Desired output dimensions for calculations.
+ crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
+ Coordinates for cropping.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 0. default height and width
+ height = int((height // 16) * 16)
+ width = int((width // 16) * 16)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ callback_on_step_end_tensor_inputs,
+ )
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = self.transformer.dtype
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ # 4. Prepare timesteps
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, mu=1
+ )
+ else:
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
+
+ prompt_embeds = prompt_embeds.to(device=device)
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ if hasattr(self.scheduler, "scale_model_input"):
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
+ dtype=latent_model_input.dtype
+ )
+
+ # predict the noise residual
+ noise_pred = self.transformer(
+ latent_model_input,
+ t_expand,
+ encoder_hidden_states=prompt_embeds,
+ return_dict=False,
+ )[0]
+
+ if noise_pred.size()[1] != self.vae.config.latent_channels:
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ latents = 1 / self.vae.config.scaling_factor * latents
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return EasyAnimatePipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
new file mode 100755
index 000000000000..1d2c508675f1
--- /dev/null
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
@@ -0,0 +1,994 @@
+# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from transformers import (
+ BertModel,
+ BertTokenizer,
+ Qwen2Tokenizer,
+ Qwen2VLForConditionalGeneration,
+)
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from .pipeline_output import EasyAnimatePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import EasyAnimateControlPipeline
+ >>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_control import get_video_to_video_latent
+ >>> from diffusers.utils import export_to_video, load_video
+
+ >>> pipe = EasyAnimateControlPipeline.from_pretrained(
+ ... "alibaba-pai/EasyAnimateV5.1-12b-zh-Control-diffusers", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> control_video = load_video(
+ ... "https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control/blob/main/asset/pose.mp4"
+ ... )
+ >>> prompt = (
+ ... "In this sunlit outdoor garden, a beautiful woman is dressed in a knee-length, sleeveless white dress. "
+ ... "The hem of her dress gently sways with her graceful dance, much like a butterfly fluttering in the breeze. "
+ ... "Sunlight filters through the leaves, casting dappled shadows that highlight her soft features and clear eyes, "
+ ... "making her appear exceptionally elegant. It seems as if every movement she makes speaks of youth and vitality. "
+ ... "As she twirls on the grass, her dress flutters, as if the entire garden is rejoicing in her dance. "
+ ... "The colorful flowers around her sway in the gentle breeze, with roses, chrysanthemums, and lilies each "
+ ... "releasing their fragrances, creating a relaxed and joyful atmosphere."
+ ... )
+ >>> sample_size = (672, 384)
+ >>> num_frames = 49
+
+ >>> input_video, _, _ = get_video_to_video_latent(control_video, num_frames, sample_size)
+ >>> video = pipe(
+ ... prompt,
+ ... num_frames=num_frames,
+ ... negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.",
+ ... height=sample_size[0],
+ ... width=sample_size[1],
+ ... control_video=input_video,
+ ... ).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=8)
+ ```
+"""
+
+
+def preprocess_image(image, sample_size):
+ """
+ Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor.
+ """
+ if isinstance(image, torch.Tensor):
+ # If input is a tensor, assume it's in CHW format and resize using interpolation
+ image = torch.nn.functional.interpolate(
+ image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False
+ ).squeeze(0)
+ elif isinstance(image, Image.Image):
+ # If input is a PIL image, resize and convert to numpy array
+ image = image.resize((sample_size[1], sample_size[0]))
+ image = np.array(image)
+ elif isinstance(image, np.ndarray):
+ # If input is a numpy array, resize using PIL
+ image = Image.fromarray(image).resize((sample_size[1], sample_size[0]))
+ image = np.array(image)
+ else:
+ raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.")
+
+ # Convert to tensor if not already
+ if not isinstance(image, torch.Tensor):
+ image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # HWC -> CHW, normalize to [0, 1]
+
+ return image
+
+
+def get_video_to_video_latent(input_video, num_frames, sample_size, validation_video_mask=None, ref_image=None):
+ if input_video is not None:
+ # Convert each frame in the list to tensor
+ input_video = [preprocess_image(frame, sample_size=sample_size) for frame in input_video]
+
+ # Stack all frames into a single tensor (F, C, H, W)
+ input_video = torch.stack(input_video)[:num_frames]
+
+ # Add batch dimension (B, F, C, H, W)
+ input_video = input_video.permute(1, 0, 2, 3).unsqueeze(0)
+
+ if validation_video_mask is not None:
+ # Handle mask input
+ validation_video_mask = preprocess_image(validation_video_mask, size=sample_size)
+ input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255)
+
+ # Adjust mask dimensions to match video
+ input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
+ input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
+ input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
+ else:
+ input_video_mask = torch.zeros_like(input_video[:, :1])
+ input_video_mask[:, :, :] = 255
+ else:
+ input_video, input_video_mask = None, None
+
+ if ref_image is not None:
+ # Convert reference image to tensor
+ ref_image = preprocess_image(ref_image, size=sample_size)
+ ref_image = ref_image.permute(1, 0, 2, 3).unsqueeze(0) # Add batch dimension (B, C, H, W)
+ else:
+ ref_image = None
+
+ return input_video, input_video_mask, ref_image
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Resize mask information in magvit
+def resize_mask(mask, latent, process_first_frame_only=True):
+ latent_size = latent.size()
+
+ if process_first_frame_only:
+ target_size = list(latent_size[2:])
+ target_size[0] = 1
+ first_frame_resized = F.interpolate(
+ mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False
+ )
+
+ target_size = list(latent_size[2:])
+ target_size[0] = target_size[0] - 1
+ if target_size[0] != 0:
+ remaining_frames_resized = F.interpolate(
+ mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False
+ )
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+ else:
+ resized_mask = first_frame_resized
+ else:
+ target_size = list(latent_size[2:])
+ resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False)
+ return resized_mask
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class EasyAnimateControlPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using EasyAnimate.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
+
+ Args:
+ vae ([`AutoencoderKLMagvit`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
+ text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
+ EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
+ tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
+ A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
+ transformer ([`EasyAnimateTransformer3DModel`]):
+ The EasyAnimate model designed by EasyAnimate Team.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKLMagvit,
+ text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
+ tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
+ transformer: EasyAnimateTransformer3DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.enable_text_attention_mask = (
+ self.transformer.config.enable_text_attention_mask
+ if getattr(self, "transformer", None) is not None
+ else True
+ )
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_spatial_compression_ratio,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+
+ # Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ dtype (`torch.dtype`):
+ torch dtype
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
+ max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
+ """
+ dtype = dtype or self.text_encoder.dtype
+ device = device or self.text_encoder.device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ if isinstance(prompt, str):
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": prompt}],
+ }
+ ]
+ else:
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": _prompt}],
+ }
+ for _prompt in prompt
+ ]
+ text = [
+ self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
+ ]
+
+ text_inputs = self.tokenizer(
+ text=text,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_attention_mask=True,
+ padding_side="right",
+ return_tensors="pt",
+ )
+ text_inputs = text_inputs.to(self.text_encoder.device)
+
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ if self.enable_text_attention_mask:
+ # Inference: Generation of the output
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
+ ).hidden_states[-2]
+ else:
+ raise ValueError("LLM needs attention_mask")
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ if negative_prompt is not None and isinstance(negative_prompt, str):
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": negative_prompt}],
+ }
+ ]
+ else:
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": _negative_prompt}],
+ }
+ for _negative_prompt in negative_prompt
+ ]
+ text = [
+ self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
+ ]
+
+ text_inputs = self.tokenizer(
+ text=text,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_attention_mask=True,
+ padding_side="right",
+ return_tensors="pt",
+ )
+ text_inputs = text_inputs.to(self.text_encoder.device)
+
+ text_input_ids = text_inputs.input_ids
+ negative_prompt_attention_mask = text_inputs.attention_mask
+ if self.enable_text_attention_mask:
+ # Inference: Generation of the output
+ negative_prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=negative_prompt_attention_mask,
+ output_hidden_states=True,
+ ).hidden_states[-2]
+ else:
+ raise ValueError("LLM needs attention_mask")
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
+
+ return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_temporal_compression_ratio + 1,
+ height // self.vae_spatial_compression_ratio,
+ width // self.vae_spatial_compression_ratio,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # scale the initial noise by the standard deviation required by the scheduler
+ if hasattr(self.scheduler, "init_noise_sigma"):
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def prepare_control_latents(
+ self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the control to latents shape as we concatenate the control to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+
+ if control is not None:
+ control = control.to(device=device, dtype=dtype)
+ bs = 1
+ new_control = []
+ for i in range(0, control.shape[0], bs):
+ control_bs = control[i : i + bs]
+ control_bs = self.vae.encode(control_bs)[0]
+ control_bs = control_bs.mode()
+ new_control.append(control_bs)
+ control = torch.cat(new_control, dim=0)
+ control = control * self.vae.config.scaling_factor
+
+ if control_image is not None:
+ control_image = control_image.to(device=device, dtype=dtype)
+ bs = 1
+ new_control_pixel_values = []
+ for i in range(0, control_image.shape[0], bs):
+ control_pixel_values_bs = control_image[i : i + bs]
+ control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
+ control_pixel_values_bs = control_pixel_values_bs.mode()
+ new_control_pixel_values.append(control_pixel_values_bs)
+ control_image_latents = torch.cat(new_control_pixel_values, dim=0)
+ control_image_latents = control_image_latents * self.vae.config.scaling_factor
+ else:
+ control_image_latents = None
+
+ return control, control_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_frames: Optional[int] = 49,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ control_video: Union[torch.FloatTensor] = None,
+ control_camera_video: Union[torch.FloatTensor] = None,
+ ref_image: Union[torch.FloatTensor] = None,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ guidance_rescale: float = 0.0,
+ timesteps: Optional[List[int]] = None,
+ ):
+ r"""
+ Generates images or video using the EasyAnimate pipeline based on the provided prompts.
+
+ Examples:
+ prompt (`str` or `List[str]`, *optional*):
+ Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
+ num_frames (`int`, *optional*):
+ Length of the generated video (in frames).
+ height (`int`, *optional*):
+ Height of the generated image in pixels.
+ width (`int`, *optional*):
+ Width of the generated image in pixels.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ Number of denoising steps during generation. More steps generally yield higher quality images but slow
+ down inference.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Encourages the model to align outputs with prompts. A higher value may decrease image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ Number of images to generate for each prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A generator to ensure reproducibility in image generation.
+ latents (`torch.Tensor`, *optional*):
+ Predefined latent tensors to condition generation.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Embeddings for negative prompts. Overrides string inputs if defined.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the primary prompt embeddings.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for negative prompt embeddings.
+ output_type (`str`, *optional*, defaults to "latent"):
+ Format of the generated output, either as a PIL image or as a NumPy array.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ If `True`, returns a structured output. Otherwise returns a simple tuple.
+ callback_on_step_end (`Callable`, *optional*):
+ Functions called at the end of each denoising step.
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
+ Tensor names to be included in callback function calls.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Adjusts noise levels based on guidance scale.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 0. default height and width
+ height = int((height // 16) * 16)
+ width = int((width // 16) * 16)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ callback_on_step_end_tensor_inputs,
+ )
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = self.transformer.dtype
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ text_encoder_index=0,
+ )
+
+ # 4. Prepare timesteps
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, mu=1
+ )
+ else:
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ if control_camera_video is not None:
+ control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True)
+ control_video_latents = control_video_latents * 6
+ control_latents = (
+ torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
+ ).to(device, dtype)
+ elif control_video is not None:
+ batch_size, channels, num_frames, height_video, width_video = control_video.shape
+ control_video = self.image_processor.preprocess(
+ control_video.permute(0, 2, 1, 3, 4).reshape(
+ batch_size * num_frames, channels, height_video, width_video
+ ),
+ height=height,
+ width=width,
+ )
+ control_video = control_video.to(dtype=torch.float32)
+ control_video = control_video.reshape(batch_size, num_frames, channels, height, width).permute(
+ 0, 2, 1, 3, 4
+ )
+ control_video_latents = self.prepare_control_latents(
+ None,
+ control_video,
+ batch_size,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )[1]
+ control_latents = (
+ torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
+ ).to(device, dtype)
+ else:
+ control_video_latents = torch.zeros_like(latents).to(device, dtype)
+ control_latents = (
+ torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
+ ).to(device, dtype)
+
+ if ref_image is not None:
+ batch_size, channels, num_frames, height_video, width_video = ref_image.shape
+ ref_image = self.image_processor.preprocess(
+ ref_image.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video),
+ height=height,
+ width=width,
+ )
+ ref_image = ref_image.to(dtype=torch.float32)
+ ref_image = ref_image.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+
+ ref_image_latents = self.prepare_control_latents(
+ None,
+ ref_image,
+ batch_size,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )[1]
+
+ ref_image_latents_conv_in = torch.zeros_like(latents)
+ if latents.size()[2] != 1:
+ ref_image_latents_conv_in[:, :, :1] = ref_image_latents
+ ref_image_latents_conv_in = (
+ torch.cat([ref_image_latents_conv_in] * 2)
+ if self.do_classifier_free_guidance
+ else ref_image_latents_conv_in
+ ).to(device, dtype)
+ control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
+ else:
+ ref_image_latents_conv_in = torch.zeros_like(latents)
+ ref_image_latents_conv_in = (
+ torch.cat([ref_image_latents_conv_in] * 2)
+ if self.do_classifier_free_guidance
+ else ref_image_latents_conv_in
+ ).to(device, dtype)
+ control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
+
+ # To latents.device
+ prompt_embeds = prompt_embeds.to(device=device)
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ if hasattr(self.scheduler, "scale_model_input"):
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
+ dtype=latent_model_input.dtype
+ )
+ # predict the noise residual
+ noise_pred = self.transformer(
+ latent_model_input,
+ t_expand,
+ encoder_hidden_states=prompt_embeds,
+ control_latents=control_latents,
+ return_dict=False,
+ )[0]
+ if noise_pred.size()[1] != self.vae.config.latent_channels:
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # Convert to tensor
+ if not output_type == "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return EasyAnimatePipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
new file mode 100755
index 000000000000..15745ecca3f0
--- /dev/null
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
@@ -0,0 +1,1234 @@
+# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from transformers import (
+ BertModel,
+ BertTokenizer,
+ Qwen2Tokenizer,
+ Qwen2VLForConditionalGeneration,
+)
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from .pipeline_output import EasyAnimatePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import EasyAnimateInpaintPipeline
+ >>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_inpaint import get_image_to_video_latent
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> pipe = EasyAnimateInpaintPipeline.from_pretrained(
+ ... "alibaba-pai/EasyAnimateV5.1-12b-zh-InP-diffusers", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+ >>> validation_image_start = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+ ... )
+
+ >>> validation_image_end = None
+ >>> sample_size = (448, 576)
+ >>> num_frames = 49
+ >>> input_video, input_video_mask = get_image_to_video_latent(
+ ... [validation_image_start], validation_image_end, num_frames, sample_size
+ ... )
+
+ >>> video = pipe(
+ ... prompt,
+ ... num_frames=num_frames,
+ ... negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.",
+ ... height=sample_size[0],
+ ... width=sample_size[1],
+ ... video=input_video,
+ ... mask_video=input_video_mask,
+ ... )
+ >>> export_to_video(video.frames[0], "output.mp4", fps=8)
+ ```
+"""
+
+
+def preprocess_image(image, sample_size):
+ """
+ Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor.
+ """
+ if isinstance(image, torch.Tensor):
+ # If input is a tensor, assume it's in CHW format and resize using interpolation
+ image = torch.nn.functional.interpolate(
+ image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False
+ ).squeeze(0)
+ elif isinstance(image, Image.Image):
+ # If input is a PIL image, resize and convert to numpy array
+ image = image.resize((sample_size[1], sample_size[0]))
+ image = np.array(image)
+ elif isinstance(image, np.ndarray):
+ # If input is a numpy array, resize using PIL
+ image = Image.fromarray(image).resize((sample_size[1], sample_size[0]))
+ image = np.array(image)
+ else:
+ raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.")
+
+ # Convert to tensor if not already
+ if not isinstance(image, torch.Tensor):
+ image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # HWC -> CHW, normalize to [0, 1]
+
+ return image
+
+
+def get_image_to_video_latent(validation_image_start, validation_image_end, num_frames, sample_size):
+ """
+ Generate latent representations for video from start and end images. Inputs can be PIL.Image, numpy.ndarray, or
+ torch.Tensor.
+ """
+ input_video = None
+ input_video_mask = None
+
+ if validation_image_start is not None:
+ # Preprocess the starting image(s)
+ if isinstance(validation_image_start, list):
+ image_start = [preprocess_image(img, sample_size) for img in validation_image_start]
+ else:
+ image_start = preprocess_image(validation_image_start, sample_size)
+
+ # Create video tensor from the starting image(s)
+ if isinstance(image_start, list):
+ start_video = torch.cat(
+ [img.unsqueeze(1).unsqueeze(0) for img in image_start],
+ dim=2,
+ )
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, num_frames, 1, 1])
+ input_video[:, :, : len(image_start)] = start_video
+ else:
+ input_video = torch.tile(
+ image_start.unsqueeze(1).unsqueeze(0),
+ [1, 1, num_frames, 1, 1],
+ )
+
+ # Normalize input video (already normalized in preprocess_image)
+
+ # Create mask for the input video
+ input_video_mask = torch.zeros_like(input_video[:, :1])
+ if isinstance(image_start, list):
+ input_video_mask[:, :, len(image_start) :] = 255
+ else:
+ input_video_mask[:, :, 1:] = 255
+
+ # Handle ending image(s) if provided
+ if validation_image_end is not None:
+ if isinstance(validation_image_end, list):
+ image_end = [preprocess_image(img, sample_size) for img in validation_image_end]
+ end_video = torch.cat(
+ [img.unsqueeze(1).unsqueeze(0) for img in image_end],
+ dim=2,
+ )
+ input_video[:, :, -len(end_video) :] = end_video
+ input_video_mask[:, :, -len(image_end) :] = 0
+ else:
+ image_end = preprocess_image(validation_image_end, sample_size)
+ input_video[:, :, -1:] = image_end.unsqueeze(1).unsqueeze(0)
+ input_video_mask[:, :, -1:] = 0
+
+ elif validation_image_start is None:
+ # If no starting image is provided, initialize empty tensors
+ input_video = torch.zeros([1, 3, num_frames, sample_size[0], sample_size[1]])
+ input_video_mask = torch.ones([1, 1, num_frames, sample_size[0], sample_size[1]]) * 255
+
+ return input_video, input_video_mask
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Resize mask information in magvit
+def resize_mask(mask, latent, process_first_frame_only=True):
+ latent_size = latent.size()
+
+ if process_first_frame_only:
+ target_size = list(latent_size[2:])
+ target_size[0] = 1
+ first_frame_resized = F.interpolate(
+ mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False
+ )
+
+ target_size = list(latent_size[2:])
+ target_size[0] = target_size[0] - 1
+ if target_size[0] != 0:
+ remaining_frames_resized = F.interpolate(
+ mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False
+ )
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+ else:
+ resized_mask = first_frame_resized
+ else:
+ target_size = list(latent_size[2:])
+ resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False)
+ return resized_mask
+
+
+## Add noise to reference video
+def add_noise_to_reference_video(image, ratio=None, generator=None):
+ if ratio is None:
+ sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
+ sigma = torch.exp(sigma).to(image.dtype)
+ else:
+ sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
+
+ if generator is not None:
+ image_noise = (
+ torch.randn(image.size(), generator=generator, dtype=image.dtype, device=image.device)
+ * sigma[:, None, None, None, None]
+ )
+ else:
+ image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
+ image_noise = torch.where(image == -1, torch.zeros_like(image), image_noise)
+ image = image + image_noise
+ return image
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class EasyAnimateInpaintPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using EasyAnimate.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
+
+ Args:
+ vae ([`AutoencoderKLMagvit`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
+ text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
+ EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
+ tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
+ A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
+ transformer ([`EasyAnimateTransformer3DModel`]):
+ The EasyAnimate model designed by EasyAnimate Team.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKLMagvit,
+ text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
+ tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
+ transformer: EasyAnimateTransformer3DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.enable_text_attention_mask = (
+ self.transformer.config.enable_text_attention_mask
+ if getattr(self, "transformer", None) is not None
+ else True
+ )
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_spatial_compression_ratio,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+
+ # Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ dtype (`torch.dtype`):
+ torch dtype
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
+ max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
+ """
+ dtype = dtype or self.text_encoder.dtype
+ device = device or self.text_encoder.device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ if isinstance(prompt, str):
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": prompt}],
+ }
+ ]
+ else:
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": _prompt}],
+ }
+ for _prompt in prompt
+ ]
+ text = [
+ self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
+ ]
+
+ text_inputs = self.tokenizer(
+ text=text,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_attention_mask=True,
+ padding_side="right",
+ return_tensors="pt",
+ )
+ text_inputs = text_inputs.to(self.text_encoder.device)
+
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ if self.enable_text_attention_mask:
+ # Inference: Generation of the output
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
+ ).hidden_states[-2]
+ else:
+ raise ValueError("LLM needs attention_mask")
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ if negative_prompt is not None and isinstance(negative_prompt, str):
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": negative_prompt}],
+ }
+ ]
+ else:
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": _negative_prompt}],
+ }
+ for _negative_prompt in negative_prompt
+ ]
+ text = [
+ self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
+ ]
+
+ text_inputs = self.tokenizer(
+ text=text,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_attention_mask=True,
+ padding_side="right",
+ return_tensors="pt",
+ )
+ text_inputs = text_inputs.to(self.text_encoder.device)
+
+ text_input_ids = text_inputs.input_ids
+ negative_prompt_attention_mask = text_inputs.attention_mask
+ if self.enable_text_attention_mask:
+ # Inference: Generation of the output
+ negative_prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=negative_prompt_attention_mask,
+ output_hidden_states=True,
+ ).hidden_states[-2]
+ else:
+ raise ValueError("LLM needs attention_mask")
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
+
+ return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ noise_aug_strength,
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ if mask is not None:
+ mask = mask.to(device=device, dtype=dtype)
+ new_mask = []
+ bs = 1
+ for i in range(0, mask.shape[0], bs):
+ mask_bs = mask[i : i + bs]
+ mask_bs = self.vae.encode(mask_bs)[0]
+ mask_bs = mask_bs.mode()
+ new_mask.append(mask_bs)
+ mask = torch.cat(new_mask, dim=0)
+ mask = mask * self.vae.config.scaling_factor
+
+ if masked_image is not None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ if self.transformer.config.add_noise_in_inpaint_model:
+ masked_image = add_noise_to_reference_video(
+ masked_image, ratio=noise_aug_strength, generator=generator
+ )
+ new_mask_pixel_values = []
+ bs = 1
+ for i in range(0, masked_image.shape[0], bs):
+ mask_pixel_values_bs = masked_image[i : i + bs]
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
+ new_mask_pixel_values.append(mask_pixel_values_bs)
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim=0)
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ else:
+ masked_image_latents = None
+
+ return mask, masked_image_latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ video=None,
+ timestep=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_video_latents=False,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_temporal_compression_ratio + 1,
+ height // self.vae_spatial_compression_ratio,
+ width // self.vae_spatial_compression_ratio,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if return_video_latents or (latents is None and not is_strength_max):
+ video = video.to(device=device, dtype=dtype)
+ bs = 1
+ new_video = []
+ for i in range(0, video.shape[0], bs):
+ video_bs = video[i : i + bs]
+ video_bs = self.vae.encode(video_bs)[0]
+ video_bs = video_bs.sample()
+ new_video.append(video_bs)
+ video = torch.cat(new_video, dim=0)
+ video = video * self.vae.config.scaling_factor
+
+ video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
+ video_latents = video_latents.to(device=device, dtype=dtype)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+ latents = noise if is_strength_max else self.scheduler.scale_noise(video_latents, timestep, noise)
+ else:
+ latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ if hasattr(self.scheduler, "init_noise_sigma"):
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ else:
+ if hasattr(self.scheduler, "init_noise_sigma"):
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_video_latents:
+ outputs += (video_latents,)
+
+ return outputs
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_frames: Optional[int] = 49,
+ video: Union[torch.FloatTensor] = None,
+ mask_video: Union[torch.FloatTensor] = None,
+ masked_video_latents: Union[torch.FloatTensor] = None,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ guidance_rescale: float = 0.0,
+ strength: float = 1.0,
+ noise_aug_strength: float = 0.0563,
+ timesteps: Optional[List[int]] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation with HunyuanDiT.
+
+ Examples:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ num_frames (`int`, *optional*):
+ Length of the video to be generated in seconds. This parameter influences the number of frames and
+ continuity of generated content.
+ video (`torch.FloatTensor`, *optional*):
+ A tensor representing an input video, which can be modified depending on the prompts provided.
+ mask_video (`torch.FloatTensor`, *optional*):
+ A tensor to specify areas of the video to be masked (omitted from generation).
+ masked_video_latents (`torch.FloatTensor`, *optional*):
+ Latents from masked portions of the video, utilized during image generation.
+ height (`int`, *optional*):
+ The height in pixels of the generated image or video frames.
+ width (`int`, *optional*):
+ The width in pixels of the generated image or video frames.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image but slower
+ inference time. This parameter is modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to exclude in image generation. If not defined, you need to provide
+ `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the
+ [`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the
+ inference process.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting
+ random seeds which helps in making generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ A pre-computed latent representation which can be used to guide the generation process.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the
+ outputs. If not provided, embeddings are generated from the `negative_prompt` argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using
+ `prompt_embeds`.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used.
+ output_type (`str`, *optional*, defaults to `"latent"`):
+ The output format of the generated image. Choose between `PIL.Image` and `np.array` to define how you
+ want the results to be formatted.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ If set to `True`, a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] will be returned;
+ otherwise, a tuple containing the generated images and safety flags will be returned.
+ callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`,
+ *optional*):
+ A callback function (or a list of them) that will be executed at the end of each denoising step,
+ allowing for custom processing during generation.
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
+ Specifies which tensor inputs should be included in the callback function. If not defined, all tensor
+ inputs will be passed, facilitating enhanced logging or monitoring of the generation process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ strength (`float`, *optional*, defaults to 1.0):
+ Affects the overall styling or quality of the generated output. Values closer to 1 usually provide
+ direct adherence to prompts.
+
+ Examples:
+ # Example usage of the function for generating images based on prompts.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ Returns either a structured output containing generated images and their metadata when `return_dict` is
+ `True`, or a simpler tuple, where the first element is a list of generated images and the second
+ element indicates if any of them contain "not-safe-for-work" (NSFW) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 0. default height and width
+ height = int(height // 16 * 16)
+ width = int(width // 16 * 16)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ callback_on_step_end_tensor_inputs,
+ )
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = self.transformer.dtype
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ # 4. set timesteps
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, mu=1
+ )
+ else:
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps=num_inference_steps, strength=strength, device=device
+ )
+
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ if video is not None:
+ batch_size, channels, num_frames, height_video, width_video = video.shape
+ init_video = self.image_processor.preprocess(
+ video.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video),
+ height=height,
+ width=width,
+ )
+ init_video = init_video.to(dtype=torch.float32)
+ init_video = init_video.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+ else:
+ init_video = None
+
+ # Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_transformer = self.transformer.config.in_channels
+ return_image_latents = num_channels_transformer == num_channels_latents
+
+ # 5. Prepare latents.
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ dtype,
+ device,
+ generator,
+ latents,
+ video=init_video,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_video_latents=return_image_latents,
+ )
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 6. Prepare inpaint latents if it needs.
+ if mask_video is not None:
+ if (mask_video == 255).all():
+ mask = torch.zeros_like(latents).to(device, dtype)
+ # Use zero latents if we want to t2v.
+ if self.transformer.config.resize_inpaint_mask_directly:
+ mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype)
+ else:
+ mask_latents = torch.zeros_like(latents).to(device, dtype)
+ masked_video_latents = torch.zeros_like(latents).to(device, dtype)
+
+ mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
+ masked_video_latents_input = (
+ torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
+ )
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype)
+ else:
+ # Prepare mask latent variables
+ batch_size, channels, num_frames, height_video, width_video = mask_video.shape
+ mask_condition = self.mask_processor.preprocess(
+ mask_video.permute(0, 2, 1, 3, 4).reshape(
+ batch_size * num_frames, channels, height_video, width_video
+ ),
+ height=height,
+ width=width,
+ )
+ mask_condition = mask_condition.to(dtype=torch.float32)
+ mask_condition = mask_condition.reshape(batch_size, num_frames, channels, height, width).permute(
+ 0, 2, 1, 3, 4
+ )
+
+ if num_channels_transformer != num_channels_latents:
+ mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
+ if masked_video_latents is None:
+ masked_video = (
+ init_video * (mask_condition_tile < 0.5)
+ + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
+ )
+ else:
+ masked_video = masked_video_latents
+
+ if self.transformer.config.resize_inpaint_mask_directly:
+ _, masked_video_latents = self.prepare_mask_latents(
+ None,
+ masked_video,
+ batch_size,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ noise_aug_strength=noise_aug_strength,
+ )
+ mask_latents = resize_mask(
+ 1 - mask_condition, masked_video_latents, self.vae.config.cache_mag_vae
+ )
+ mask_latents = mask_latents.to(device, dtype) * self.vae.config.scaling_factor
+ else:
+ mask_latents, masked_video_latents = self.prepare_mask_latents(
+ mask_condition_tile,
+ masked_video,
+ batch_size,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ noise_aug_strength=noise_aug_strength,
+ )
+
+ mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
+ masked_video_latents_input = (
+ torch.cat([masked_video_latents] * 2)
+ if self.do_classifier_free_guidance
+ else masked_video_latents
+ )
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype)
+ else:
+ inpaint_latents = None
+
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode="trilinear", align_corners=True).to(
+ device, dtype
+ )
+ else:
+ if num_channels_transformer != num_channels_latents:
+ mask = torch.zeros_like(latents).to(device, dtype)
+ if self.transformer.config.resize_inpaint_mask_directly:
+ mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype)
+ else:
+ mask_latents = torch.zeros_like(latents).to(device, dtype)
+ masked_video_latents = torch.zeros_like(latents).to(device, dtype)
+
+ mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
+ masked_video_latents_input = (
+ torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
+ )
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype)
+ else:
+ mask = torch.zeros_like(init_video[:, :1])
+ mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode="trilinear", align_corners=True).to(
+ device, dtype
+ )
+
+ inpaint_latents = None
+
+ # Check that sizes of mask, masked image and latents match
+ if num_channels_transformer != num_channels_latents:
+ num_channels_mask = mask_latents.shape[1]
+ num_channels_masked_image = masked_video_latents.shape[1]
+ if (
+ num_channels_latents + num_channels_mask + num_channels_masked_image
+ != self.transformer.config.in_channels
+ ):
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
+ f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.transformer` or your `mask_image` or `image` input."
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
+
+ # To latents.device
+ prompt_embeds = prompt_embeds.to(device=device)
+ prompt_attention_mask = prompt_attention_mask.to(device=device)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ if hasattr(self.scheduler, "scale_model_input"):
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
+ t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
+ dtype=latent_model_input.dtype
+ )
+
+ # predict the noise residual
+ noise_pred = self.transformer(
+ latent_model_input,
+ t_expand,
+ encoder_hidden_states=prompt_embeds,
+ inpaint_latents=inpaint_latents,
+ return_dict=False,
+ )[0]
+ if noise_pred.size()[1] != self.vae.config.latent_channels:
+ noise_pred, _ = noise_pred.chunk(2, dim=1)
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if num_channels_transformer == num_channels_latents:
+ init_latents_proper = image_latents
+ init_mask = mask
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep], noise)
+ )
+ else:
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ latents = 1 / self.vae.config.scaling_factor * latents
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return EasyAnimatePipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_output.py b/src/diffusers/pipelines/easyanimate/pipeline_output.py
new file mode 100644
index 000000000000..c761a3b1079f
--- /dev/null
+++ b/src/diffusers/pipelines/easyanimate/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class EasyAnimatePipelineOutput(BaseOutput):
+ r"""
+ Output class for EasyAnimate pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py
index 0ebf5ea6d78d..72e1b578f2ca 100644
--- a/src/diffusers/pipelines/flux/__init__.py
+++ b/src/diffusers/pipelines/flux/__init__.py
@@ -12,7 +12,7 @@
_dummy_objects = {}
_additional_imports = {}
-_import_structure = {"pipeline_output": ["FluxPipelineOutput"]}
+_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
@@ -22,12 +22,18 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
+ _import_structure["modeling_flux"] = ["ReduxImageEncoder"]
_import_structure["pipeline_flux"] = ["FluxPipeline"]
+ _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
+ _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
+ _import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"]
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
_import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
_import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
+ _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
+ _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -35,12 +41,18 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
+ from .modeling_flux import ReduxImageEncoder
from .pipeline_flux import FluxPipeline
+ from .pipeline_flux_control import FluxControlPipeline
+ from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
+ from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline
from .pipeline_flux_controlnet import FluxControlNetPipeline
from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
+ from .pipeline_flux_fill import FluxFillPipeline
from .pipeline_flux_img2img import FluxImg2ImgPipeline
from .pipeline_flux_inpaint import FluxInpaintPipeline
+ from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/flux/modeling_flux.py b/src/diffusers/pipelines/flux/modeling_flux.py
new file mode 100644
index 000000000000..5ff60f774d19
--- /dev/null
+++ b/src/diffusers/pipelines/flux/modeling_flux.py
@@ -0,0 +1,47 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+from ...utils import BaseOutput
+
+
+@dataclass
+class ReduxImageEncoderOutput(BaseOutput):
+ image_embeds: Optional[torch.Tensor] = None
+
+
+class ReduxImageEncoder(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ redux_dim: int = 1152,
+ txt_in_features: int = 4096,
+ ) -> None:
+ super().__init__()
+
+ self.redux_up = nn.Linear(redux_dim, txt_in_features * 3)
+ self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features)
+
+ def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput:
+ projected_x = self.redux_down(nn.functional.silu(self.redux_up(x)))
+
+ return ReduxImageEncoderOutput(image_embeds=projected_x)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 8278365e9467..862c279cfaf3 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -17,12 +17,18 @@
import numpy as np
import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
-from ...image_processor import VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
-from ...models.autoencoders import AutoencoderKL
-from ...models.transformers import FluxTransformer2DModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
@@ -69,7 +75,7 @@ def calculate_shift(
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
- max_shift: float = 1.16,
+ max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
@@ -142,6 +148,7 @@ class FluxPipeline(
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
):
r"""
The Flux pipeline for text-to-image generation.
@@ -169,8 +176,8 @@ class FluxPipeline(
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
- model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
- _optional_components = []
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
@@ -182,6 +189,8 @@ def __init__(
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
@@ -193,15 +202,17 @@ def __init__(
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
- )
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
- self.default_sample_size = 64
+ self.default_sample_size = 128
def _get_t5_prompt_embeds(
self,
@@ -375,19 +386,72 @@ def encode_prompt(
return prompt_embeds, pooled_prompt_embeds, text_ids
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
+ negative_prompt=None,
+ negative_prompt_2=None,
prompt_embeds=None,
+ negative_prompt_embeds=None,
pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -415,19 +479,42 @@ def check_inputs(
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
@@ -449,13 +536,15 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
- height = height // vae_scale_factor
- width = width // vae_scale_factor
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
@@ -499,13 +588,15 @@ def prepare_latents(
generator,
latents=None,
):
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
@@ -517,7 +608,7 @@ def prepare_latents(
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents, latent_image_ids
@@ -533,6 +624,10 @@ def joint_attention_kwargs(self):
def num_timesteps(self):
return self._num_timesteps
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
@property
def interrupt(self):
return self._interrupt
@@ -543,16 +638,25 @@ def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -569,7 +673,16 @@ def __call__(
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
- will be used instead
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -577,11 +690,11 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 7.0):
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
@@ -602,6 +715,25 @@ def __call__(
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -639,14 +771,19 @@ def __call__(
prompt_2,
height,
width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
@@ -662,6 +799,10 @@ def __call__(
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
(
prompt_embeds,
pooled_prompt_embeds,
@@ -676,6 +817,21 @@ def __call__(
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ _,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
@@ -691,21 +847,20 @@ def __call__(
)
# 5. Prepare timesteps
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
- self.scheduler.config.base_image_seq_len,
- self.scheduler.config.max_image_seq_len,
- self.scheduler.config.base_shift,
- self.scheduler.config.max_shift,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
- timesteps,
- sigmas,
+ sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -718,12 +873,47 @@ def __call__(
else:
guidance = None
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
+ self._current_timestep = t
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -739,6 +929,22 @@ def __call__(
return_dict=False,
)[0]
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -764,9 +970,10 @@ def __call__(
if XLA_AVAILABLE:
xm.mark_step()
+ self._current_timestep = None
+
if output_type == "latent":
image = latents
-
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py
new file mode 100644
index 000000000000..113b0dd7291f
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py
@@ -0,0 +1,886 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from controlnet_aux import CannyDetector
+ >>> from diffusers import FluxControlPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = FluxControlPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16
+ ... ).to("cuda")
+
+ >>> prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+ >>> control_image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png"
+ ... )
+
+ >>> processor = CannyDetector()
+ >>> control_image = processor(
+ ... control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
+ ... )
+
+ >>> image = pipe(
+ ... prompt=prompt,
+ ... control_image=control_image,
+ ... height=1024,
+ ... width=1024,
+ ... num_inference_steps=50,
+ ... guidance_scale=30.0,
+ ... ).images[0]
+ >>> image.save("output.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxControlPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ The Flux pipeline for controllable text-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.vae_latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.vae_latent_channels
+ )
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 8
+
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ if control_image.ndim == 4:
+ control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator)
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ height_control_image, width_control_image = control_image.shape[2:]
+ control_image = self._pack_latents(
+ control_image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents, control_image], dim=2)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
new file mode 100644
index 000000000000..c269be15a4b2
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
@@ -0,0 +1,941 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from controlnet_aux import CannyDetector
+ >>> from diffusers import FluxControlImg2ImgPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = FluxControlImg2ImgPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16
+ ... ).to("cuda")
+
+ >>> prompt = "A robot made of exotic candies and chocolates of different kinds. Abstract background"
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/watercolor-painting.jpg"
+ ... )
+ >>> control_image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png"
+ ... )
+
+ >>> processor = CannyDetector()
+ >>> control_image = processor(
+ ... control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
+ ... )
+
+ >>> image = pipe(
+ ... prompt=prompt,
+ ... image=image,
+ ... control_image=control_image,
+ ... strength=0.8,
+ ... height=1024,
+ ... width=1024,
+ ... num_inference_steps=50,
+ ... guidance_scale=30.0,
+ ... ).images[0]
+ >>> image.save("output.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+ r"""
+ The Flux pipeline for image inpainting.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ strength,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ return latents, latent_image_ids
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ strength,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Preprocess image
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4.Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 8
+
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ if control_image.ndim == 4:
+ control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator)
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ height_control_image, width_control_image = control_image.shape[2:]
+ control_image = self._pack_latents(
+ control_image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents, control_image], dim=2)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
new file mode 100644
index 000000000000..af7e8b53fad3
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
@@ -0,0 +1,1139 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+)
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ import torch
+ from diffusers import FluxControlInpaintPipeline
+ from diffusers.models.transformers import FluxTransformer2DModel
+ from transformers import T5EncoderModel
+ from diffusers.utils import load_image, make_image_grid
+ from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux
+ from PIL import Image
+ import numpy as np
+
+ pipe = FluxControlInpaintPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Depth-dev",
+ torch_dtype=torch.bfloat16,
+ )
+ # use following lines if you have GPU constraints
+ # ---------------------------------------------------------------
+ transformer = FluxTransformer2DModel.from_pretrained(
+ "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
+ )
+ text_encoder_2 = T5EncoderModel.from_pretrained(
+ "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+ )
+ pipe.transformer = transformer
+ pipe.text_encoder_2 = text_encoder_2
+ pipe.enable_model_cpu_offload()
+ # ---------------------------------------------------------------
+ pipe.to("cuda")
+
+ prompt = "a blue robot singing opera with human-like expressions"
+ image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+
+ head_mask = np.zeros_like(image)
+ head_mask[65:580, 300:642] = 255
+ mask_image = Image.fromarray(head_mask)
+
+ processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
+ control_image = processor(image)[0].convert("RGB")
+
+ output = pipe(
+ prompt=prompt,
+ image=image,
+ control_image=control_image,
+ mask_image=mask_image,
+ num_inference_steps=30,
+ strength=0.9,
+ guidance_scale=10.0,
+ generator=torch.Generator().manual_seed(42),
+ ).images[0]
+ make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save(
+ "output.png"
+ )
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FluxControlInpaintPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ The Flux pipeline for image inpainting using Flux-dev-Depth/Canny.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ strength,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ return latents, noise, image_latents, latent_image_ids
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def prepare_mask_latents(
+ self,
+ image,
+ mask_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ image = self.image_processor.preprocess(image, height=height, width=width)
+ mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+ masked_image = image * (1 - mask_image)
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask_image = torch.nn.functional.interpolate(mask_image, size=(height, width))
+ mask_image = mask_image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == num_channels_latents:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask_image.shape[0] < batch_size:
+ if not batch_size % mask_image.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask_image.shape[0]} mask_image were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask_image = self._pack_latents(
+ mask_image.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ masked_image_latents = torch.cat((masked_image_latents, mask_image), dim=-1)
+
+ return mask_image, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+ 1)`, or `(H, W)`.
+ mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
+ `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
+ latents tensor will ge generated by `mask_image`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ strength,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+ device = self._execution_device
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 3. Preprocess mask and image
+ num_channels_latents = self.vae.config.latent_channels
+ if masked_image_latents is not None:
+ # pre computed masked_image_latents and mask_image
+ masked_image_latents = masked_image_latents.to(latents.device)
+ mask = mask_image.to(latents.device)
+ else:
+ mask, masked_image_latents = self.prepare_mask_latents(
+ image,
+ mask_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 4.Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 8
+
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ if control_image.ndim == 4:
+ control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator)
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ height_control_image, width_control_image = control_image.shape[2:]
+ control_image = self._pack_latents(
+ control_image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
+
+ latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height_8 = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width_8 = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents, control_image], dim=2)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # for 64 channel transformer only.
+ init_mask = mask
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ image_latents, torch.tensor([noise_timestep]), noise
+ )
+ else:
+ init_latents_proper = image_latents
+ init_latents_proper = self._pack_latents(
+ init_latents_proper, batch_size * num_images_per_prompt, num_channels_latents, height_8, width_8
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 5136c4200147..f3f1d90204d6 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -18,16 +18,18 @@
import numpy as np
import torch
from transformers import (
+ CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
+ CLIPVisionModelWithProjection,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
@@ -61,6 +63,7 @@
>>> from diffusers import FluxControlNetPipeline
>>> from diffusers import FluxControlNetModel
+ >>> base_model = "black-forest-labs/FLUX.1-dev"
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
>>> pipe = FluxControlNetPipeline.from_pretrained(
@@ -89,7 +92,7 @@ def calculate_shift(
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
- max_shift: float = 1.16,
+ max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
@@ -97,6 +100,20 @@ def calculate_shift(
return mu
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -157,7 +174,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin):
r"""
The Flux pipeline for text-to-image generation.
@@ -184,9 +201,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
- model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
- _optional_components = []
- _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image"]
def __init__(
self,
@@ -200,6 +217,8 @@ def __init__(
controlnet: Union[
FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
],
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
if isinstance(controlnet, (list, tuple)):
@@ -214,15 +233,17 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
controlnet=controlnet,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
- )
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
- self.default_sample_size = 64
+ self.default_sample_size = 128
def _get_t5_prompt_embeds(
self,
@@ -399,19 +420,74 @@ def encode_prompt(
return prompt_embeds, pooled_prompt_embeds, text_ids
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
+ negative_prompt=None,
+ negative_prompt_2=None,
prompt_embeds=None,
+ negative_prompt_embeds=None,
pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -439,10 +515,33 @@ def check_inputs(
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -450,9 +549,9 @@ def check_inputs(
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
@@ -476,13 +575,15 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
- height = height // vae_scale_factor
- width = width // vae_scale_factor
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
@@ -498,13 +599,15 @@ def prepare_latents(
generator,
latents=None,
):
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
@@ -516,7 +619,7 @@ def prepare_latents(
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents, latent_image_ids
@@ -577,10 +680,13 @@ def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
@@ -592,6 +698,12 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -616,10 +728,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -659,6 +771,17 @@ def __call__(
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -707,8 +830,12 @@ def __call__(
prompt_2,
height,
width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -728,9 +855,11 @@ def __call__(
device = self._execution_device
dtype = self.transformer.dtype
+ # 3. Prepare text embeddings
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
+ do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
(
prompt_embeds,
pooled_prompt_embeds,
@@ -745,6 +874,21 @@ def __call__(
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ _,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
# 3. Prepare control image
num_channels_latents = self.transformer.config.in_channels // 4
@@ -764,7 +908,7 @@ def __call__(
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
if self.controlnet.input_hint_block is None:
# vae encode
- control_image = self.vae.encode(control_image).latent_dist.sample()
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
@@ -802,7 +946,7 @@ def __call__(
if self.controlnet.nets[0].input_hint_block is None:
# vae encode
- control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+ control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
@@ -849,21 +993,20 @@ def __call__(
)
# 5. Prepare timesteps
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
- self.scheduler.config.base_image_seq_len,
- self.scheduler.config.max_image_seq_len,
- self.scheduler.config.base_shift,
- self.scheduler.config.max_shift,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
- timesteps,
- sigmas,
+ sigmas=sigmas,
mu=mu,
)
@@ -879,12 +1022,43 @@ def __call__(
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -940,6 +1114,25 @@ def __call__(
controlnet_blocks_repeat=controlnet_blocks_repeat,
)[0]
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ controlnet_block_samples=controlnet_block_samples,
+ controlnet_single_block_samples=controlnet_single_block_samples,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ controlnet_blocks_repeat=controlnet_blocks_repeat,
+ )[0]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -957,6 +1150,7 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ control_image = callback_outputs.pop("control_image", control_image)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index 7b40ddfca79a..ddd5372b4dd8 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -13,7 +13,7 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
@@ -87,7 +87,7 @@ def calculate_shift(
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
- max_shift: float = 1.16,
+ max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
@@ -198,7 +198,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
- _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image"]
def __init__(
self,
@@ -227,14 +227,14 @@ def __init__(
scheduler=scheduler,
controlnet=controlnet,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
- )
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
- self.default_sample_size = 64
+ self.default_sample_size = 128
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
@@ -453,8 +453,10 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if height % self.vae_scale_factor * 2 != 0 or width % self.vae_scale_factor * 2 != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -493,9 +495,9 @@ def check_inputs(
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
@@ -519,17 +521,18 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
- height = height // vae_scale_factor
- width = width // vae_scale_factor
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
- # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
def prepare_latents(
self,
image,
@@ -549,11 +552,12 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
-
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
if latents is not None:
return latents.to(device=device, dtype=dtype), latent_image_ids
@@ -639,7 +643,7 @@ def __call__(
width: Optional[int] = None,
strength: float = 0.6,
num_inference_steps: int = 28,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
@@ -678,8 +682,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
control_mode (`int` or `List[int]`, *optional*):
@@ -794,7 +800,7 @@ def __call__(
)
height, width = control_image.shape[-2:]
- control_image = self.vae.encode(control_image).latent_dist.sample()
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
height_control_image, width_control_image = control_image.shape[2:]
@@ -825,7 +831,7 @@ def __call__(
)
height, width = control_image_.shape[-2:]
- control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+ control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
height_control_image, width_control_image = control_image_.shape[2:]
@@ -851,27 +857,25 @@ def __call__(
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
- self.scheduler.config.base_image_seq_len,
- self.scheduler.config.max_image_seq_len,
- self.scheduler.config.base_shift,
- self.scheduler.config.max_shift,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
- timesteps,
- sigmas,
+ sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-
latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
@@ -903,9 +907,12 @@ def __call__(
timestep = t.expand(latents.shape[0]).to(latents.dtype)
- guidance = (
- torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
- )
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
+ else:
+ use_guidance = self.controlnet.config.guidance_embeds
+
+ guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
if isinstance(controlnet_keep[i], list):
@@ -965,6 +972,7 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ control_image = callback_outputs.pop("control_image", control_image)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index 46784f2d46d1..bff625367bc9 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -14,7 +14,7 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
@@ -89,7 +89,7 @@ def calculate_shift(
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
- max_shift: float = 1.16,
+ max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
@@ -200,7 +200,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
- _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image", "mask", "masked_image_latents"]
def __init__(
self,
@@ -230,13 +230,14 @@ def __init__(
controlnet=controlnet,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
- )
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.mask_processor = VaeImageProcessor(
- vae_scale_factor=self.vae_scale_factor,
- vae_latent_channels=self.vae.config.latent_channels,
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
@@ -244,7 +245,7 @@ def __init__(
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
- self.default_sample_size = 64
+ self.default_sample_size = 128
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
@@ -467,8 +468,10 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -520,9 +523,9 @@ def check_inputs(
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
@@ -546,17 +549,18 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
- height = height // vae_scale_factor
- width = width // vae_scale_factor
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
- # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
def prepare_latents(
self,
image,
@@ -576,11 +580,12 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
-
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
@@ -608,7 +613,6 @@ def prepare_latents(
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, noise, image_latents, latent_image_ids
- # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
def prepare_mask_latents(
self,
mask,
@@ -622,8 +626,10 @@ def prepare_mask_latents(
device,
generator,
):
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
@@ -661,7 +667,6 @@ def prepare_mask_latents(
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
-
masked_image_latents = self._pack_latents(
masked_image_latents,
batch_size,
@@ -744,7 +749,7 @@ def __call__(
width: Optional[int] = None,
strength: float = 0.6,
padding_mask_crop: Optional[int] = None,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
num_inference_steps: int = 28,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
@@ -791,8 +796,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
@@ -921,8 +928,8 @@ def __call__(
if isinstance(self.controlnet, FluxControlNetModel):
control_image = self.prepare_image(
image=control_image,
- width=height,
- height=width,
+ width=width,
+ height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
@@ -930,19 +937,22 @@ def __call__(
)
height, width = control_image.shape[-2:]
- # vae encode
- control_image = self.vae.encode(control_image).latent_dist.sample()
- control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-
- # pack
- height_control_image, width_control_image = control_image.shape[2:]
- control_image = self._pack_latents(
- control_image,
- batch_size * num_images_per_prompt,
- num_channels_latents,
- height_control_image,
- width_control_image,
- )
+ # xlab controlnet has a input_hint_block and instantx controlnet does not
+ controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
+ if self.controlnet.input_hint_block is None:
+ # vae encode
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # pack
+ height_control_image, width_control_image = control_image.shape[2:]
+ control_image = self._pack_latents(
+ control_image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
# set control mode
if control_mode is not None:
@@ -952,7 +962,9 @@ def __call__(
elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []
- for control_image_ in control_image:
+ # xlab controlnet has a input_hint_block and instantx controlnet does not
+ controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
+ for i, control_image_ in enumerate(control_image):
control_image_ = self.prepare_image(
image=control_image_,
width=width,
@@ -964,19 +976,20 @@ def __call__(
)
height, width = control_image_.shape[-2:]
- # vae encode
- control_image_ = self.vae.encode(control_image_).latent_dist.sample()
- control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-
- # pack
- height_control_image, width_control_image = control_image_.shape[2:]
- control_image_ = self._pack_latents(
- control_image_,
- batch_size * num_images_per_prompt,
- num_channels_latents,
- height_control_image,
- width_control_image,
- )
+ if self.controlnet.nets[0].input_hint_block is None:
+ # vae encode
+ control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ # pack
+ height_control_image, width_control_image = control_image_.shape[2:]
+ control_image_ = self._pack_latents(
+ control_image_,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
control_images.append(control_image_)
@@ -995,21 +1008,22 @@ def __call__(
# 6. Prepare timesteps
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor)
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (
+ int(global_width) // self.vae_scale_factor // 2
+ )
mu = calculate_shift(
image_seq_len,
- self.scheduler.config.base_image_seq_len,
- self.scheduler.config.max_image_seq_len,
- self.scheduler.config.base_shift,
- self.scheduler.config.max_shift,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
- timesteps,
- sigmas,
+ sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
@@ -1078,7 +1092,11 @@ def __call__(
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# predict the noise residual
- if self.controlnet.config.guidance_embeds:
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
+ else:
+ use_guidance = self.controlnet.config.guidance_embeds
+ if use_guidance:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
@@ -1125,6 +1143,7 @@ def __call__(
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
+ controlnet_blocks_repeat=controlnet_blocks_repeat,
)[0]
# compute the previous noisy sample x_t -> x_t-1
@@ -1157,6 +1176,9 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ control_image = callback_outputs.pop("control_image", control_image)
+ mask = callback_outputs.pop("mask", mask)
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
new file mode 100644
index 000000000000..1816b3ca6d9b
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
@@ -0,0 +1,968 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxFillPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png")
+ >>> mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png")
+
+ >>> pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
+ >>> pipe.enable_model_cpu_offload() # save some VRAM by offloading the model to CPU
+
+ >>> image = pipe(
+ ... prompt="a white paper cup",
+ ... image=image,
+ ... mask_image=mask,
+ ... height=1632,
+ ... width=1232,
+ ... guidance_scale=30,
+ ... num_inference_steps=50,
+ ... max_sequence_length=512,
+ ... generator=torch.Generator("cpu").manual_seed(0),
+ ... ).images[0]
+ >>> image.save("flux_fill.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class FluxFillPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ The Flux Fill pipeline for image inpainting/outpainting.
+
+ Reference: https://blackforestlabs.ai/flux-1-tools/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # 1. calculate the height and width of the latents
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ # 2. encode the masked image
+ if masked_image.shape[1] == num_channels_latents:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ # 3. duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ batch_size = batch_size * num_images_per_prompt
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # 4. pack the masked_image_latents
+ # batch_size, num_channels_latents, height, width -> batch_size, height//2 * width//2 , num_channels_latents*4
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ # 5.resize mask to latents shape we we concatenate the mask to the latents
+ mask = mask[:, 0, :, :] # batch_size, 8 * height, 8 * width (mask has not been 8x compressed)
+ mask = mask.view(
+ batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor
+ ) # batch_size, height, 8, width, 8
+ mask = mask.permute(0, 2, 4, 1, 3) # batch_size, 8, 8, height, width
+ mask = mask.reshape(
+ batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width
+ ) # batch_size, 8*8, height, width
+
+ # 6. pack the mask:
+ # batch_size, 64, height, width -> batch_size, height//2 * width//2 , 64*2*2
+ mask = self._pack_latents(
+ mask,
+ batch_size,
+ self.vae_scale_factor * self.vae_scale_factor,
+ height,
+ width,
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ image=None,
+ mask_image=None,
+ masked_image_latents=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ if image is not None and masked_image_latents is not None:
+ raise ValueError(
+ "Please provide either `image` or `masked_image_latents`, `masked_image_latents` should not be passed."
+ )
+
+ if image is not None and mask_image is None:
+ raise ValueError("Please provide `mask_image` when passing `image`.")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: Optional[torch.FloatTensor] = None,
+ mask_image: Optional[torch.FloatTensor] = None,
+ masked_image_latents: Optional[torch.FloatTensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 30.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+ 1)`, or `(H, W)`.
+ mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
+ `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
+ latents tensor will ge generated by `mask_image`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 30.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ image=image,
+ mask_image=mask_image,
+ masked_image_latents=masked_image_latents,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Prepare prompt embeddings
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare mask and masked image latents
+ if masked_image_latents is not None:
+ masked_image_latents = masked_image_latents.to(latents.device)
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+ mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+ masked_image = image * (1 - mask_image)
+ masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
+
+ height, width = image.shape[-2:]
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_image,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+ masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
+
+ # 6. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=torch.cat((latents, masked_image_latents), dim=2),
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # 8. Post-process the image
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
index 112260003ef5..64cd6ac45f1a 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -17,10 +17,17 @@
import numpy as np
import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -77,7 +84,7 @@ def calculate_shift(
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
- max_shift: float = 1.16,
+ max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
@@ -159,7 +166,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
+class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin):
r"""
The Flux pipeline for image inpainting.
@@ -186,8 +193,8 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
- model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
- _optional_components = []
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
@@ -199,6 +206,8 @@ def __init__(
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
@@ -210,15 +219,20 @@ def __init__(
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
)
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
- self.default_sample_size = 64
+ self.default_sample_size = 128
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
@@ -395,6 +409,55 @@ def encode_prompt(
return prompt_embeds, pooled_prompt_embeds, text_ids
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
@@ -429,16 +492,22 @@ def check_inputs(
strength,
height,
width,
+ negative_prompt=None,
+ negative_prompt_2=None,
prompt_embeds=None,
+ negative_prompt_embeds=None,
pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -466,10 +535,33 @@ def check_inputs(
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -477,9 +569,9 @@ def check_inputs(
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
@@ -503,13 +595,15 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
- height = height // vae_scale_factor
- width = width // vae_scale_factor
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
@@ -532,17 +626,21 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
-
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
if latents is not None:
return latents.to(device=device, dtype=dtype), latent_image_ids
image = image.to(device=device, dtype=dtype)
- image_latents = self._encode_vae_image(image=image, generator=generator)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
@@ -581,18 +679,27 @@ def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.6,
num_inference_steps: int = 28,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -629,10 +736,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -654,6 +761,17 @@ def __call__(
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -692,8 +810,12 @@ def __call__(
strength,
height,
width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -719,6 +841,7 @@ def __call__(
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
+ do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
(
prompt_embeds,
pooled_prompt_embeds,
@@ -733,23 +856,37 @@ def __call__(
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ _,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
# 4.Prepare timesteps
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
- self.scheduler.config.base_image_seq_len,
- self.scheduler.config.max_image_seq_len,
- self.scheduler.config.base_shift,
- self.scheduler.config.max_shift,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
- timesteps,
- sigmas,
+ sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
@@ -787,12 +924,43 @@ def __call__(
else:
guidance = None
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
@@ -807,6 +975,22 @@ def __call__(
return_dict=False,
)[0]
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index ae348c0f6421..27b9e0cd45fa 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -18,10 +18,17 @@
import numpy as np
import PIL.Image
import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -74,7 +81,7 @@ def calculate_shift(
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
- max_shift: float = 1.16,
+ max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
@@ -156,7 +163,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
+class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterMixin):
r"""
The Flux pipeline for image inpainting.
@@ -183,8 +190,8 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
- model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
- _optional_components = []
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
@@ -196,6 +203,8 @@ def __init__(
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
@@ -207,14 +216,19 @@ def __init__(
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
)
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
- vae_scale_factor=self.vae_scale_factor,
- vae_latent_channels=self.vae.config.latent_channels,
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
@@ -222,7 +236,7 @@ def __init__(
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
- self.default_sample_size = 64
+ self.default_sample_size = 128
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
@@ -399,6 +413,55 @@ def encode_prompt(
return prompt_embeds, pooled_prompt_embeds, text_ids
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
@@ -436,8 +499,12 @@ def check_inputs(
height,
width,
output_type,
+ negative_prompt=None,
+ negative_prompt_2=None,
prompt_embeds=None,
+ negative_prompt_embeds=None,
pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None,
max_sequence_length=None,
@@ -445,8 +512,10 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -474,10 +543,33 @@ def check_inputs(
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
@@ -498,9 +590,9 @@ def check_inputs(
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
@@ -524,13 +616,15 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
- height = height // vae_scale_factor
- width = width // vae_scale_factor
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
@@ -553,14 +647,18 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
-
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
image = image.to(device=device, dtype=dtype)
- image_latents = self._encode_vae_image(image=image, generator=generator)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
@@ -598,8 +696,10 @@ def prepare_mask_latents(
device,
generator,
):
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
@@ -615,7 +715,9 @@ def prepare_mask_latents(
else:
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
- masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ masked_image_latents = (
+ masked_image_latents - self.vae.config.shift_factor
+ ) * self.vae.config.scaling_factor
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -637,7 +739,6 @@ def prepare_mask_latents(
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
-
masked_image_latents = self._pack_latents(
masked_image_latents,
batch_size,
@@ -677,6 +778,9 @@ def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
masked_image_latents: PipelineImageInput = None,
@@ -685,13 +789,19 @@ def __call__(
padding_mask_crop: Optional[int] = None,
strength: float = 0.6,
num_inference_steps: int = 28,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -745,10 +855,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -770,6 +880,17 @@ def __call__(
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -811,8 +932,12 @@ def __call__(
height,
width,
output_type=output_type,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
padding_mask_crop=padding_mask_crop,
max_sequence_length=max_sequence_length,
@@ -849,6 +974,7 @@ def __call__(
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
+ do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
(
prompt_embeds,
pooled_prompt_embeds,
@@ -863,23 +989,37 @@ def __call__(
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ _,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
# 4.Prepare timesteps
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
- self.scheduler.config.base_image_seq_len,
- self.scheduler.config.max_image_seq_len,
- self.scheduler.config.base_shift,
- self.scheduler.config.max_shift,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
- timesteps,
- sigmas,
+ sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
@@ -940,12 +1080,43 @@ def __call__(
else:
guidance = None
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
@@ -960,6 +1131,22 @@ def __call__(
return_dict=False,
)[0]
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
new file mode 100644
index 000000000000..f53958df2ed0
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -0,0 +1,492 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Union
+
+import torch
+from PIL import Image
+from transformers import (
+ CLIPTextModel,
+ CLIPTokenizer,
+ SiglipImageProcessor,
+ SiglipVisionModel,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput
+from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ..pipeline_utils import DiffusionPipeline
+from .modeling_flux import ReduxImageEncoder
+from .pipeline_output import FluxPriorReduxPipelineOutput
+
+
+if is_torch_xla_available():
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxPriorReduxPipeline, FluxPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> device = "cuda"
+ >>> dtype = torch.bfloat16
+
+ >>> repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
+ >>> repo_base = "black-forest-labs/FLUX.1-dev"
+ >>> pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
+ >>> pipe = FluxPipeline.from_pretrained(
+ ... repo_base, text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16
+ ... ).to(device)
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png"
+ ... )
+ >>> pipe_prior_output = pipe_prior_redux(image)
+ >>> images = pipe(
+ ... guidance_scale=2.5,
+ ... num_inference_steps=50,
+ ... generator=torch.Generator("cpu").manual_seed(0),
+ ... **pipe_prior_output,
+ ... ).images
+ >>> images[0].save("flux-redux.png")
+ ```
+"""
+
+
+class FluxPriorReduxPipeline(DiffusionPipeline):
+ r"""
+ The Flux Redux pipeline for image-to-image generation.
+
+ Reference: https://blackforestlabs.ai/flux-1-tools/
+
+ Args:
+ image_encoder ([`SiglipVisionModel`]):
+ SIGLIP vision model to encode the input image.
+ feature_extractor ([`SiglipImageProcessor`]):
+ Image processor for preprocessing images for the SIGLIP model.
+ image_embedder ([`ReduxImageEncoder`]):
+ Redux image encoder to process the SIGLIP embeddings.
+ text_encoder ([`CLIPTextModel`], *optional*):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`], *optional*):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`, *optional*):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`, *optional*):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "image_encoder->image_embedder"
+ _optional_components = [
+ "text_encoder",
+ "tokenizer",
+ "text_encoder_2",
+ "tokenizer_2",
+ ]
+ _callback_tensor_inputs = []
+
+ def __init__(
+ self,
+ image_encoder: SiglipVisionModel,
+ feature_extractor: SiglipImageProcessor,
+ image_embedder: ReduxImageEncoder,
+ text_encoder: CLIPTextModel = None,
+ tokenizer: CLIPTokenizer = None,
+ text_encoder_2: T5EncoderModel = None,
+ tokenizer_2: T5TokenizerFast = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ image_embedder=image_embedder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ )
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+
+ def check_inputs(
+ self,
+ image,
+ prompt,
+ prompt_2,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ prompt_embeds_scale=1.0,
+ pooled_prompt_embeds_scale=1.0,
+ ):
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)):
+ raise ValueError(
+ f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images"
+ )
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if isinstance(prompt_embeds_scale, list) and (
+ isinstance(image, list) and len(prompt_embeds_scale) != len(image)
+ ):
+ raise ValueError(
+ f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images"
+ )
+
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+ image = self.feature_extractor.preprocess(
+ images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True
+ )
+ image = image.to(device=device, dtype=dtype)
+
+ image_enc_hidden_states = self.image_encoder(**image).last_hidden_state
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+
+ return image_enc_hidden_states
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
+ pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
+ return_dict: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. **experimental feature**: to use this feature,
+ make sure to explicitly load text encoders to the pipeline. Prompts will be ignored if text encoders
+ are not loaded.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPriorReduxPipelineOutput`] or `tuple`:
+ [`~pipelines.flux.FluxPriorReduxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ image,
+ prompt,
+ prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ prompt_embeds_scale=prompt_embeds_scale,
+ pooled_prompt_embeds_scale=pooled_prompt_embeds_scale,
+ )
+
+ # 2. Define call parameters
+ if image is not None and isinstance(image, Image.Image):
+ batch_size = 1
+ elif image is not None and isinstance(image, list):
+ batch_size = len(image)
+ else:
+ batch_size = image.shape[0]
+ if prompt is not None and isinstance(prompt, str):
+ prompt = batch_size * [prompt]
+ if isinstance(prompt_embeds_scale, float):
+ prompt_embeds_scale = batch_size * [prompt_embeds_scale]
+ if isinstance(pooled_prompt_embeds_scale, float):
+ pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]
+
+ device = self._execution_device
+
+ # 3. Prepare image embeddings
+ image_latents = self.encode_image(image, device, 1)
+
+ image_embeds = self.image_embedder(image_latents).image_embeds
+ image_embeds = image_embeds.to(device=device)
+
+ # 3. Prepare (dummy) text embeddings
+ if hasattr(self, "text_encoder") and self.text_encoder is not None:
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ _,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=1,
+ max_sequence_length=512,
+ lora_scale=None,
+ )
+ else:
+ if prompt is not None:
+ logger.warning(
+ "prompt input is ignored when text encoders are not loaded to the pipeline. "
+ "Make sure to explicitly load the text encoders to enable prompt input. "
+ )
+ # max_sequence_length is 512, t5 encoder hidden size is 4096
+ prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
+ # pooled_prompt_embeds is 768, clip text encoder hidden size
+ pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)
+
+ # scale & concatenate image and text embeddings
+ prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
+
+ prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
+ pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[
+ :, None
+ ]
+
+ # weighted sum
+ prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
+ pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (prompt_embeds, pooled_prompt_embeds)
+
+ return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds)
diff --git a/src/diffusers/pipelines/flux/pipeline_output.py b/src/diffusers/pipelines/flux/pipeline_output.py
index b5d98fb5bf60..388824e89f87 100644
--- a/src/diffusers/pipelines/flux/pipeline_output.py
+++ b/src/diffusers/pipelines/flux/pipeline_output.py
@@ -3,6 +3,7 @@
import numpy as np
import PIL.Image
+import torch
from ...utils import BaseOutput
@@ -19,3 +20,18 @@ class FluxPipelineOutput(BaseOutput):
"""
images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+@dataclass
+class FluxPriorReduxPipelineOutput(BaseOutput):
+ """
+ Output class for Flux Prior Redux pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ prompt_embeds: torch.Tensor
+ pooled_prompt_embeds: torch.Tensor
diff --git a/src/diffusers/pipelines/hunyuan_video/__init__.py b/src/diffusers/pipelines/hunyuan_video/__init__.py
new file mode 100644
index 000000000000..d9cacad24f17
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_video/__init__.py
@@ -0,0 +1,52 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"]
+ _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
+ _import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline
+ from .pipeline_hunyuan_video import HunyuanVideoPipeline
+ from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py
new file mode 100644
index 000000000000..297d2a9c9396
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py
@@ -0,0 +1,804 @@
+# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import HunyuanVideoLoraLoaderMixin
+from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import HunyuanVideoPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import HunyuanSkyreelsImageToVideoPipeline, HunyuanVideoTransformer3DModel
+ >>> from diffusers.utils import load_image, export_to_video
+
+ >>> model_id = "hunyuanvideo-community/HunyuanVideo"
+ >>> transformer_model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
+ >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ ... transformer_model_id, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe = HunyuanSkyreelsImageToVideoPipeline.from_pretrained(
+ ... model_id, transformer=transformer, torch_dtype=torch.float16
+ ... )
+ >>> pipe.vae.enable_tiling()
+ >>> pipe.to("cuda")
+
+ >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+ >>> negative_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+ ... )
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... num_inference_steps=30,
+ ... true_cfg_scale=6.0,
+ ... guidance_scale=1.0,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=15)
+ ```
+"""
+
+
+DEFAULT_PROMPT_TEMPLATE = {
+ "template": (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ ),
+ "crop_start": 95,
+}
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class HunyuanSkyreelsImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation using HunyuanVideo.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ text_encoder ([`LlamaModel`]):
+ [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
+ tokenizer (`LlamaTokenizer`):
+ Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
+ transformer ([`HunyuanVideoTransformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLHunyuanVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder_2 ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer_2 (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ text_encoder: LlamaModel,
+ tokenizer: LlamaTokenizerFast,
+ transformer: HunyuanVideoTransformer3DModel,
+ vae: AutoencoderKLHunyuanVideo,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ text_encoder_2: CLIPTextModel,
+ tokenizer_2: CLIPTokenizer,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_llama_prompt_embeds
+ def _get_llama_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_template: Dict[str, Any],
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ num_hidden_layers_to_skip: int = 2,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ prompt = [prompt_template["template"].format(p) for p in prompt]
+
+ crop_start = prompt_template.get("crop_start", None)
+ if crop_start is None:
+ prompt_template_input = self.tokenizer(
+ prompt_template["template"],
+ padding="max_length",
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=False,
+ )
+ crop_start = prompt_template_input["input_ids"].shape[-1]
+ # Remove <|eot_id|> token and placeholder {}
+ crop_start -= 2
+
+ max_sequence_length += crop_start
+ text_inputs = self.tokenizer(
+ prompt,
+ max_length=max_sequence_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=True,
+ )
+ text_input_ids = text_inputs.input_ids.to(device=device)
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device)
+
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=prompt_attention_mask,
+ output_hidden_states=True,
+ ).hidden_states[-(num_hidden_layers_to_skip + 1)]
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
+
+ if crop_start is not None and crop_start > 0:
+ prompt_embeds = prompt_embeds[:, crop_start:]
+ prompt_attention_mask = prompt_attention_mask[:, crop_start:]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 77,
+ ) -> torch.Tensor:
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_2.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]] = None,
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ ):
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
+ prompt,
+ prompt_template,
+ num_videos_per_prompt,
+ device=device,
+ dtype=dtype,
+ max_sequence_length=max_sequence_length,
+ )
+
+ if pooled_prompt_embeds is None:
+ if prompt_2 is None:
+ prompt_2 = prompt
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt,
+ num_videos_per_prompt,
+ device=device,
+ dtype=dtype,
+ max_sequence_length=77,
+ )
+
+ return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_template=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_template is not None:
+ if not isinstance(prompt_template, dict):
+ raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
+ if "template" not in prompt_template:
+ raise ValueError(
+ f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
+ )
+
+ def prepare_latents(
+ self,
+ image: torch.Tensor,
+ batch_size: int,
+ num_channels_latents: int = 32,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 97,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ image = image.unsqueeze(2) # [B, C, 1, H, W]
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+ ]
+ else:
+ image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
+
+ image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ padding_shape = (batch_size, num_channels_latents, num_latent_frames - 1, latent_height, latent_width)
+
+ latents_padding = torch.zeros(padding_shape, dtype=dtype, device=device)
+ image_latents = torch.cat([image_latents, latents_padding], dim=2)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, dtype=dtype, device=device)
+ else:
+ latents = latents.to(dtype=dtype, device=device)
+
+ return latents, image_latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Union[str, List[str]] = None,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 97,
+ num_inference_steps: int = 50,
+ sigmas: List[float] = None,
+ true_cfg_scale: float = 6.0,
+ guidance_scale: float = 1.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+ max_sequence_length: int = 256,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ height (`int`, defaults to `720`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `129`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ guidance_scale (`float`, defaults to `6.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
+ CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
+ not applied.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ prompt_template,
+ )
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_template=prompt_template,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
+
+ if do_true_cfg:
+ negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_template=prompt_template,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+
+ # 5. Prepare latent variables
+ vae_dtype = self.vae.dtype
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, vae_dtype)
+ num_channels_latents = self.transformer.config.in_channels // 2
+ latents, image_latents = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+ latent_image_input = image_latents.to(transformer_dtype)
+
+ # 6. Prepare guidance condition
+ guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ pooled_projections=pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_attention_mask,
+ pooled_projections=negative_pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return HunyuanVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
new file mode 100644
index 000000000000..3cb91b3782f2
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
@@ -0,0 +1,754 @@
+# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import HunyuanVideoLoraLoaderMixin
+from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import HunyuanVideoPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+ >>> from diffusers.utils import export_to_video
+
+ >>> model_id = "hunyuanvideo-community/HunyuanVideo"
+ >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
+ >>> pipe.vae.enable_tiling()
+ >>> pipe.to("cuda")
+
+ >>> output = pipe(
+ ... prompt="A cat walks on the grass, realistic",
+ ... height=320,
+ ... width=512,
+ ... num_frames=61,
+ ... num_inference_steps=30,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=15)
+ ```
+"""
+
+
+DEFAULT_PROMPT_TEMPLATE = {
+ "template": (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ ),
+ "crop_start": 95,
+}
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using HunyuanVideo.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ text_encoder ([`LlamaModel`]):
+ [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
+ tokenizer (`LlamaTokenizer`):
+ Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
+ transformer ([`HunyuanVideoTransformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLHunyuanVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder_2 ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer_2 (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ text_encoder: LlamaModel,
+ tokenizer: LlamaTokenizerFast,
+ transformer: HunyuanVideoTransformer3DModel,
+ vae: AutoencoderKLHunyuanVideo,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ text_encoder_2: CLIPTextModel,
+ tokenizer_2: CLIPTokenizer,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_llama_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_template: Dict[str, Any],
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ num_hidden_layers_to_skip: int = 2,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ prompt = [prompt_template["template"].format(p) for p in prompt]
+
+ crop_start = prompt_template.get("crop_start", None)
+ if crop_start is None:
+ prompt_template_input = self.tokenizer(
+ prompt_template["template"],
+ padding="max_length",
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=False,
+ )
+ crop_start = prompt_template_input["input_ids"].shape[-1]
+ # Remove <|eot_id|> token and placeholder {}
+ crop_start -= 2
+
+ max_sequence_length += crop_start
+ text_inputs = self.tokenizer(
+ prompt,
+ max_length=max_sequence_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=True,
+ )
+ text_input_ids = text_inputs.input_ids.to(device=device)
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device)
+
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=prompt_attention_mask,
+ output_hidden_states=True,
+ ).hidden_states[-(num_hidden_layers_to_skip + 1)]
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
+
+ if crop_start is not None and crop_start > 0:
+ prompt_embeds = prompt_embeds[:, crop_start:]
+ prompt_attention_mask = prompt_attention_mask[:, crop_start:]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 77,
+ ) -> torch.Tensor:
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_2.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]] = None,
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ ):
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
+ prompt,
+ prompt_template,
+ num_videos_per_prompt,
+ device=device,
+ dtype=dtype,
+ max_sequence_length=max_sequence_length,
+ )
+
+ if pooled_prompt_embeds is None:
+ if prompt_2 is None:
+ prompt_2 = prompt
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt,
+ num_videos_per_prompt,
+ device=device,
+ dtype=dtype,
+ max_sequence_length=77,
+ )
+
+ return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_template=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_template is not None:
+ if not isinstance(prompt_template, dict):
+ raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
+ if "template" not in prompt_template:
+ raise ValueError(
+ f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
+ )
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 32,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 129,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Union[str, List[str]] = None,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 129,
+ num_inference_steps: int = 50,
+ sigmas: List[float] = None,
+ true_cfg_scale: float = 1.0,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+ max_sequence_length: int = 256,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ height (`int`, defaults to `720`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `129`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ guidance_scale (`float`, defaults to `6.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
+ CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
+ not applied.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ prompt_template,
+ )
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_template=prompt_template,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
+
+ if do_true_cfg:
+ negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_template=prompt_template,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare guidance condition
+ guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ pooled_projections=pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_attention_mask,
+ pooled_projections=negative_pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return HunyuanVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
new file mode 100644
index 000000000000..774b72e6c7c1
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
@@ -0,0 +1,924 @@
+# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ LlamaTokenizerFast,
+ LlavaForConditionalGeneration,
+)
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import HunyuanVideoLoraLoaderMixin
+from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import HunyuanVideoPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel
+ >>> from diffusers.utils import load_image, export_to_video
+
+ >>> # Available checkpoints: hunyuanvideo-community/HunyuanVideo-I2V, hunyuanvideo-community/HunyuanVideo-I2V-33ch
+ >>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V"
+ >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
+ ... model_id, transformer=transformer, torch_dtype=torch.float16
+ ... )
+ >>> pipe.vae.enable_tiling()
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A man with short gray hair plays a red electric guitar."
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png"
+ ... )
+
+ >>> # If using hunyuanvideo-community/HunyuanVideo-I2V
+ >>> output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0]
+
+ >>> # If using hunyuanvideo-community/HunyuanVideo-I2V-33ch
+ >>> output = pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0).frames[0]
+
+ >>> export_to_video(output, "output.mp4", fps=15)
+ ```
+"""
+
+
+DEFAULT_PROMPT_TEMPLATE = {
+ "template": (
+ "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ ),
+ "crop_start": 103,
+ "image_emb_start": 5,
+ "image_emb_end": 581,
+ "image_emb_len": 576,
+ "double_return_token_id": 271,
+}
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation using HunyuanVideo.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ text_encoder ([`LlavaForConditionalGeneration`]):
+ [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
+ tokenizer (`LlamaTokenizer`):
+ Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
+ transformer ([`HunyuanVideoTransformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLHunyuanVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder_2 ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer_2 (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ text_encoder: LlavaForConditionalGeneration,
+ tokenizer: LlamaTokenizerFast,
+ transformer: HunyuanVideoTransformer3DModel,
+ vae: AutoencoderKLHunyuanVideo,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ text_encoder_2: CLIPTextModel,
+ tokenizer_2: CLIPTokenizer,
+ image_processor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ image_processor=image_processor,
+ )
+
+ self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986
+ self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_llama_prompt_embeds(
+ self,
+ image: torch.Tensor,
+ prompt: Union[str, List[str]],
+ prompt_template: Dict[str, Any],
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ num_hidden_layers_to_skip: int = 2,
+ image_embed_interleave: int = 2,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_template["template"].format(p) for p in prompt]
+
+ crop_start = prompt_template.get("crop_start", None)
+ if crop_start is None:
+ prompt_template_input = self.tokenizer(
+ prompt_template["template"],
+ padding="max_length",
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=False,
+ )
+ crop_start = prompt_template_input["input_ids"].shape[-1]
+ # Remove <|start_header_id|>, <|end_header_id|>, assistant, <|eot_id|>, and placeholder {}
+ crop_start -= 5
+
+ max_sequence_length += crop_start
+ text_inputs = self.tokenizer(
+ prompt,
+ max_length=max_sequence_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=True,
+ )
+ text_input_ids = text_inputs.input_ids.to(device=device)
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device)
+
+ image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
+
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=prompt_attention_mask,
+ pixel_values=image_embeds,
+ output_hidden_states=True,
+ ).hidden_states[-(num_hidden_layers_to_skip + 1)]
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
+
+ image_emb_len = prompt_template.get("image_emb_len", 576)
+ image_emb_start = prompt_template.get("image_emb_start", 5)
+ image_emb_end = prompt_template.get("image_emb_end", 581)
+ double_return_token_id = prompt_template.get("double_return_token_id", 271)
+
+ if crop_start is not None and crop_start > 0:
+ text_crop_start = crop_start - 1 + image_emb_len
+ batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
+
+ if last_double_return_token_indices.shape[0] == 3:
+ # in case the prompt is too long
+ last_double_return_token_indices = torch.cat(
+ (last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]]))
+ )
+ batch_indices = torch.cat((batch_indices, torch.tensor([0])))
+
+ last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[
+ :, -1
+ ]
+ batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1]
+ assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4
+ assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len
+ attention_mask_assistant_crop_start = last_double_return_token_indices - 4
+ attention_mask_assistant_crop_end = last_double_return_token_indices
+
+ prompt_embed_list = []
+ prompt_attention_mask_list = []
+ image_embed_list = []
+ image_attention_mask_list = []
+
+ for i in range(text_input_ids.shape[0]):
+ prompt_embed_list.append(
+ torch.cat(
+ [
+ prompt_embeds[i, text_crop_start : assistant_crop_start[i].item()],
+ prompt_embeds[i, assistant_crop_end[i].item() :],
+ ]
+ )
+ )
+ prompt_attention_mask_list.append(
+ torch.cat(
+ [
+ prompt_attention_mask[i, crop_start : attention_mask_assistant_crop_start[i].item()],
+ prompt_attention_mask[i, attention_mask_assistant_crop_end[i].item() :],
+ ]
+ )
+ )
+ image_embed_list.append(prompt_embeds[i, image_emb_start:image_emb_end])
+ image_attention_mask_list.append(
+ torch.ones(image_embed_list[-1].shape[0]).to(prompt_embeds.device).to(prompt_attention_mask.dtype)
+ )
+
+ prompt_embed_list = torch.stack(prompt_embed_list)
+ prompt_attention_mask_list = torch.stack(prompt_attention_mask_list)
+ image_embed_list = torch.stack(image_embed_list)
+ image_attention_mask_list = torch.stack(image_attention_mask_list)
+
+ if 0 < image_embed_interleave < 6:
+ image_embed_list = image_embed_list[:, ::image_embed_interleave, :]
+ image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave]
+
+ assert (
+ prompt_embed_list.shape[0] == prompt_attention_mask_list.shape[0]
+ and image_embed_list.shape[0] == image_attention_mask_list.shape[0]
+ )
+
+ prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1)
+ prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 77,
+ ) -> torch.Tensor:
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_2.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ image: torch.Tensor,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]] = None,
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ image_embed_interleave: int = 2,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
+ image,
+ prompt,
+ prompt_template,
+ num_videos_per_prompt,
+ device=device,
+ dtype=dtype,
+ max_sequence_length=max_sequence_length,
+ image_embed_interleave=image_embed_interleave,
+ )
+
+ if pooled_prompt_embeds is None:
+ if prompt_2 is None:
+ prompt_2 = prompt
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt,
+ num_videos_per_prompt,
+ device=device,
+ dtype=dtype,
+ max_sequence_length=77,
+ )
+
+ return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_template=None,
+ true_cfg_scale=1.0,
+ guidance_scale=1.0,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_template is not None:
+ if not isinstance(prompt_template, dict):
+ raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
+ if "template" not in prompt_template:
+ raise ValueError(
+ f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
+ )
+
+ if true_cfg_scale > 1.0 and guidance_scale > 1.0:
+ logger.warning(
+ "Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both "
+ "classifier-free guidance and embedded-guidance to be applied. This is not recommended "
+ "as it may lead to higher memory usage, slower inference and potentially worse results."
+ )
+
+ def prepare_latents(
+ self,
+ image: torch.Tensor,
+ batch_size: int,
+ num_channels_latents: int = 32,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 129,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ image_condition_type: str = "latent_concat",
+ ) -> torch.Tensor:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+
+ image = image.unsqueeze(2) # [B, C, 1, H, W]
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax")
+ for i in range(batch_size)
+ ]
+ else:
+ image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image]
+
+ image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
+ image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ t = torch.tensor([0.999]).to(device=device)
+ latents = latents * t + image_latents * (1 - t)
+
+ if image_condition_type == "token_replace":
+ image_latents = image_latents[:, :, :1]
+
+ return latents, image_latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PIL.Image.Image,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Union[str, List[str]] = None,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 129,
+ num_inference_steps: int = 50,
+ sigmas: List[float] = None,
+ true_cfg_scale: float = 1.0,
+ guidance_scale: float = 1.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+ max_sequence_length: int = 256,
+ image_embed_interleave: Optional[int] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ height (`int`, defaults to `720`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `129`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ guidance_scale (`float`, defaults to `1.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
+ CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
+ not applied.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ prompt_template,
+ true_cfg_scale,
+ guidance_scale,
+ )
+
+ image_condition_type = self.transformer.config.image_condition_type
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ image_embed_interleave = (
+ image_embed_interleave
+ if image_embed_interleave is not None
+ else (
+ 2 if image_condition_type == "latent_concat" else 4 if image_condition_type == "token_replace" else 1
+ )
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Prepare latent variables
+ vae_dtype = self.vae.dtype
+ image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
+
+ if image_condition_type == "latent_concat":
+ num_channels_latents = (self.transformer.config.in_channels - 1) // 2
+ elif image_condition_type == "token_replace":
+ num_channels_latents = self.transformer.config.in_channels
+
+ latents, image_latents = self.prepare_latents(
+ image_tensor,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ image_condition_type,
+ )
+ if image_condition_type == "latent_concat":
+ image_latents[:, :, 1:] = 0
+ mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
+ mask[:, :, 1:] = 0
+
+ # 4. Encode input prompt
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
+ image=image,
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_template=prompt_template,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ image_embed_interleave=image_embed_interleave,
+ )
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
+
+ if do_true_cfg:
+ black_image = PIL.Image.new("RGB", (width, height), 0)
+ negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
+ image=black_image,
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_template=prompt_template,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+
+ # 6. Prepare guidance condition
+ guidance = None
+ if self.transformer.config.guidance_embeds:
+ guidance = (
+ torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+ )
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ if image_condition_type == "latent_concat":
+ latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
+ elif image_condition_type == "token_replace":
+ latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ pooled_projections=pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_attention_mask,
+ pooled_projections=negative_pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if image_condition_type == "latent_concat":
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+ elif image_condition_type == "token_replace":
+ latents = latents = self.scheduler.step(
+ noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False
+ )[0]
+ latents = torch.cat([image_latents, latents], dim=2)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype) / self.vae_scaling_factor
+ video = self.vae.decode(latents, return_dict=False)[0]
+ if image_condition_type == "latent_concat":
+ video = video[:, :, 4:, :, :]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ if image_condition_type == "latent_concat":
+ video = latents[:, :, 1:, :, :]
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return HunyuanVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_output.py b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py
new file mode 100644
index 000000000000..c5cb853e3932
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class HunyuanVideoPipelineOutput(BaseOutput):
+ r"""
+ Output class for HunyuanVideo pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
index bda718cb197d..febf2b0392cc 100644
--- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
+++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
@@ -207,8 +207,8 @@ def __init__(
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
- text_encoder_2=T5EncoderModel,
- tokenizer_2=MT5Tokenizer,
+ text_encoder_2: Optional[T5EncoderModel] = None,
+ tokenizer_2: Optional[MT5Tokenizer] = None,
):
super().__init__()
@@ -240,9 +240,7 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = (
@@ -798,7 +796,11 @@ def __call__(
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
- self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
+ self.transformer.inner_dim // self.transformer.num_heads,
+ grid_crops_coords,
+ (grid_height, grid_width),
+ device=device,
+ output_type="pt",
)
style = torch.tensor([0], device=device)
diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
index f528b60e6ed7..58d65a190d5b 100644
--- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
+++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
@@ -27,6 +27,7 @@
from ...schedulers import DDIMScheduler
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -35,8 +36,16 @@
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -133,7 +142,7 @@ def __init__(
unet=unet,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# `do_resize=False` as we do custom resizing.
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False)
@@ -711,6 +720,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index b2041e101564..b5f4acf5c05a 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -22,6 +22,7 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -30,8 +31,16 @@
from .text_encoder import MultilingualCLIP
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -385,6 +394,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
index fe9909770376..e653b8266f19 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
@@ -193,15 +193,15 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
- def enable_sequential_cpu_offload(self, gpu_id=0):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
"""
- self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
- self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+ self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+ self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -411,7 +411,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
- def enable_sequential_cpu_offload(self, gpu_id=0):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -419,8 +419,8 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
Note that offloading happens on a submodule basis. Memory savings are higher than with
`enable_model_cpu_offload`, but performance is lower.
"""
- self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
- self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+ self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+ self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
@@ -652,7 +652,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
- def enable_sequential_cpu_offload(self, gpu_id=0):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -660,8 +660,8 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
Note that offloading happens on a submodule basis. Memory savings are higher than with
`enable_model_cpu_offload`, but performance is lower.
"""
- self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
- self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+ self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+ self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
index ef5241fee5d2..5d56efef9287 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
@@ -25,6 +25,7 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -33,8 +34,16 @@
from .text_encoder import MultilingualCLIP
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -478,6 +487,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 7. post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index 778b6e314c0d..cce5f0b3d5bc 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -29,6 +29,7 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -37,8 +38,16 @@
from .text_encoder import MultilingualCLIP
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -613,6 +622,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
index b5152d71cb6b..a348deef8b29 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
@@ -24,6 +24,7 @@
from ...schedulers import UnCLIPScheduler
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -31,8 +32,16 @@
from ..pipeline_utils import DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -519,6 +528,9 @@ def __call__(
prev_timestep=prev_timestep,
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
latents = self.prior.post_process_latents(latents)
image_embeddings = latents
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
index 471db61556f5..a584674540d8 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
@@ -18,13 +18,21 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -296,6 +304,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
index 0130c3951b38..bada59080c7b 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
@@ -19,14 +19,23 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -297,6 +306,10 @@ def __call__(
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
index 12be1534c642..4f6c4188bd48 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
@@ -22,14 +22,23 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -358,6 +367,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
index 899273a1a736..624748896911 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
@@ -21,13 +21,21 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -372,6 +380,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
index b5ba7a0011a1..482093a4bb29 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
@@ -25,13 +25,21 @@
from ... import __version__
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -526,6 +534,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
index f2134b22b40b..d05a7fbdb1b8 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
@@ -7,6 +7,7 @@
from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -15,8 +16,16 @@
from ..pipeline_utils import DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -524,6 +533,9 @@ def __call__(
)
text_mask = callback_outputs.pop("text_mask", text_mask)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
latents = self.prior.post_process_latents(latents)
image_embeddings = latents
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
index ec6509bb3cb5..56d326e26e6e 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
@@ -7,6 +7,7 @@
from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -15,8 +16,16 @@
from ..pipeline_utils import DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -538,6 +547,9 @@ def __call__(
prev_timestep=prev_timestep,
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
latents = self.prior.post_process_latents(latents)
image_embeddings = latents
diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
index 8dbae2a1909a..5309f94a53c8 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
@@ -8,6 +8,7 @@
from ...schedulers import DDPMScheduler
from ...utils import (
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -15,8 +16,16 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -549,6 +558,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
index 81c45c4fb6f8..fbdad79db445 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
@@ -12,6 +12,7 @@
from ...schedulers import DDPMScheduler
from ...utils import (
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -19,8 +20,16 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -617,6 +626,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py
index 1d2d07572d68..1fc4c02cc43f 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py
@@ -19,7 +19,7 @@
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
+from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers
@@ -121,7 +121,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
+class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin):
r"""
Pipeline for text-to-image generation using Kolors.
@@ -129,8 +129,8 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
@@ -188,12 +188,14 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
def encode_prompt(
self,
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
index 6ddda7acf2a8..df94ec3f0f24 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
@@ -207,12 +207,14 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
# Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt
def encode_prompt(
diff --git a/src/diffusers/pipelines/kolors/text_encoder.py b/src/diffusers/pipelines/kolors/text_encoder.py
index 6fb6f18a907a..757569c880c0 100644
--- a/src/diffusers/pipelines/kolors/text_encoder.py
+++ b/src/diffusers/pipelines/kolors/text_encoder.py
@@ -104,13 +104,6 @@ def forward(self, hidden_states: torch.Tensor):
return (self.weight * hidden_states).to(input_dtype)
-def _config_to_kwargs(args):
- common_kwargs = {
- "dtype": args.torch_dtype,
- }
- return common_kwargs
-
-
class CoreAttention(torch.nn.Module):
def __init__(self, config: ChatGLMConfig, layer_number):
super(CoreAttention, self).__init__()
@@ -314,7 +307,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
self.qkv_hidden_size,
bias=config.add_bias_linear or config.add_qkv_bias,
device=device,
- **_config_to_kwargs(config),
)
self.core_attention = CoreAttention(config, self.layer_number)
@@ -325,7 +317,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
config.hidden_size,
bias=config.add_bias_linear,
device=device,
- **_config_to_kwargs(config),
)
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
@@ -449,7 +440,6 @@ def __init__(self, config: ChatGLMConfig, device=None):
config.ffn_hidden_size * 2,
bias=self.add_bias,
device=device,
- **_config_to_kwargs(config),
)
def swiglu(x):
@@ -459,9 +449,7 @@ def swiglu(x):
self.activation_func = swiglu
# Project back to h.
- self.dense_4h_to_h = nn.Linear(
- config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
- )
+ self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device)
def forward(self, hidden_states):
# [s, b, 4hp]
@@ -488,18 +476,14 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
- self.input_layernorm = LayerNormFunc(
- config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
- )
+ self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
# Self attention.
self.self_attention = SelfAttention(config, layer_number, device=device)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
- self.post_attention_layernorm = LayerNormFunc(
- config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
- )
+ self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
# MLP
self.mlp = MLP(config, device=device)
@@ -569,9 +553,7 @@ def build_layer(layer_number):
if self.post_layer_norm:
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output.
- self.final_layernorm = LayerNormFunc(
- config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
- )
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
self.gradient_checkpointing = False
@@ -590,7 +572,7 @@ def forward(
if not kv_caches:
kv_caches = [None for _ in range(self.num_layers)]
presents = () if use_cache else None
- if self.gradient_checkpointing and self.training:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -604,8 +586,8 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)
layer = self._get_layer(index)
- if self.gradient_checkpointing and self.training:
- layer_ret = torch.utils.checkpoint.checkpoint(
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ layer_ret = self._gradient_checkpointing_func(
layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
)
else:
@@ -666,10 +648,6 @@ def get_position_ids(self, input_ids, device):
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
return position_ids
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, GLMTransformer):
- module.gradient_checkpointing = value
-
def default_init(cls, *args, **kwargs):
return cls(*args, **kwargs)
@@ -683,9 +661,7 @@ def __init__(self, config: ChatGLMConfig, device=None):
self.hidden_size = config.hidden_size
# Word embeddings (parallel).
- self.word_embeddings = nn.Embedding(
- config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
- )
+ self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device)
self.fp32_residual_connection = config.fp32_residual_connection
def forward(self, input_ids):
@@ -788,16 +764,13 @@ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
)
- self.rotary_pos_emb = RotaryEmbedding(
- rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
- )
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device)
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
self.output_layer = init_method(
nn.Linear,
config.hidden_size,
config.padded_vocab_size,
bias=False,
- dtype=config.torch_dtype,
**init_kwargs,
)
self.pre_seq_len = config.pre_seq_len
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
index e985648abace..1c59ca7d6d7c 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -30,6 +30,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -40,6 +41,13 @@
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -226,7 +234,7 @@ def __init__(
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
@@ -952,6 +960,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
denoised = denoised.to(prompt_embeds.dtype)
if not output_type == "latent":
image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
index d110cd464522..a3d9917d3376 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -29,6 +29,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -39,8 +40,16 @@
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -209,7 +218,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -881,6 +890,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
denoised = denoised.to(prompt_embeds.dtype)
if not output_type == "latent":
image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index f6f3531a8835..c7aa76a01fb8 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -25,10 +25,19 @@
from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
class LDMTextToImagePipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using latent diffusion.
@@ -202,6 +211,9 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# scale and decode the image latents with vae
latents = 1 / self.vqvae.config.scaling_factor * latents
image = self.vqvae.decode(latents).sample
@@ -532,10 +544,6 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, (LDMBertEncoder,)):
- module.gradient_checkpointing = value
-
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
@@ -675,16 +683,9 @@ def forward(
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
- if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs, output_attentions)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(encoder_layer),
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
index bb72b4d4eb8e..879722e6a0e2 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
@@ -15,11 +15,19 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
-from ...utils import PIL_INTERPOLATION
+from ...utils import PIL_INTERPOLATION, is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
def preprocess(image):
w, h = image.size
w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
@@ -174,6 +182,9 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# decode the image latents with the VQVAE
image = self.vqvae.decode(latents).sample
image = torch.clamp(image, -1.0, 1.0)
diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py
index 19c4a6d1ddf9..e9a95e8be45c 100644
--- a/src/diffusers/pipelines/latte/pipeline_latte.py
+++ b/src/diffusers/pipelines/latte/pipeline_latte.py
@@ -30,8 +30,10 @@
from ...utils import (
BACKENDS_MAPPING,
BaseOutput,
+ deprecate,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -39,8 +41,16 @@
from ...video_processor import VideoProcessor
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -180,7 +190,7 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
@@ -592,6 +602,10 @@ def do_classifier_free_guidance(self):
def num_timesteps(self):
return self._num_timesteps
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
@property
def interrupt(self):
return self._interrupt
@@ -623,7 +637,7 @@ def __call__(
clean_caption: bool = True,
mask_feature: bool = True,
enable_temporal_attentions: bool = True,
- decode_chunk_size: Optional[int] = None,
+ decode_chunk_size: int = 14,
) -> Union[LattePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -719,6 +733,7 @@ def __call__(
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
+ self._current_timestep = None
self._interrupt = False
# 2. Default height and width to transformer
@@ -780,6 +795,7 @@ def __call__(
if self.interrupt:
continue
+ self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -788,10 +804,11 @@ def __call__(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
+ is_npu = latent_model_input.device.type == "npu"
if isinstance(current_timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
@@ -800,7 +817,7 @@ def __call__(
# predict noise model_output
noise_pred = self.transformer(
- latent_model_input,
+ hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=current_timestep,
enable_temporal_attentions=enable_temporal_attentions,
@@ -836,8 +853,20 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
- if not output_type == "latents":
- video = self.decode_latents(latents, video_length, decode_chunk_size=14)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latents":
+ deprecation_message = (
+ "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead."
+ )
+ deprecate("output_type_latents", "1.0.0", deprecation_message, standard_warn=False)
+ output_type = "latent"
+
+ if not output_type == "latent":
+ video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
index f0f71080d0a3..bdac47c47ade 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -19,6 +19,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -29,26 +30,32 @@
from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
- >>> import PIL
- >>> import requests
>>> import torch
- >>> from io import BytesIO
>>> from diffusers import LEditsPPPipelineStableDiffusion
>>> from diffusers.utils import load_image
>>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ... "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
... )
+ >>> pipe.enable_vae_tiling()
>>> pipe = pipe.to("cuda")
>>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
- >>> image = load_image(img_url).convert("RGB")
+ >>> image = load_image(img_url).resize((512, 512))
>>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
@@ -152,7 +159,7 @@ def __init__(self, device):
# The gaussian kernel is the product of the gaussian function of each dimension.
kernel = 1
- meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -318,7 +325,7 @@ def __init__(
"The scheduler has been changed to DPMSolverMultistepScheduler."
)
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -332,7 +339,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -361,10 +368,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -391,7 +402,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -706,6 +717,35 @@ def clip_skip(self):
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1182,6 +1222,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post-processing
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
@@ -1271,6 +1314,8 @@ def invert(
[`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
and respective VAE reconstruction(s).
"""
+ if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+ raise ValueError("height and width must be a factor of 32.")
# Reset attn processor, we do not want to store attn maps during inversion
self.unet.set_attn_processor(AttnProcessor())
@@ -1349,6 +1394,9 @@ def invert(
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1)
zs = zs.flip(0)
self.zs = zs
@@ -1360,6 +1408,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
image = self.image_processor.preprocess(
image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
+ height, width = image.shape[-2:]
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(
+ "Image height and width must be a factor of 32. "
+ "Consider down-sampling the input using the `height` and `width` parameters"
+ )
resized = self.image_processor.postprocess(image=image, output_type="pil")
if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
index 834445bfcd06..cad7d8a66a08 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
@@ -72,25 +72,18 @@
Examples:
```py
>>> import torch
- >>> import PIL
- >>> import requests
- >>> from io import BytesIO
>>> from diffusers import LEditsPPPipelineStableDiffusionXL
+ >>> from diffusers.utils import load_image
>>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
- ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16
... )
+ >>> pipe.enable_vae_tiling()
>>> pipe = pipe.to("cuda")
-
- >>> def download_image(url):
- ... response = requests.get(url)
- ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
-
-
>>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
- >>> image = download_image(img_url)
+ >>> image = load_image(img_url).resize((1024, 1024))
>>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)
@@ -197,7 +190,7 @@ def __init__(self, device):
# The gaussian kernel is the product of the gaussian function of each dimension.
kernel = 1
- meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -379,7 +372,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
if not isinstance(scheduler, DDIMScheduler) and not isinstance(scheduler, DPMSolverMultistepScheduler):
@@ -391,7 +384,11 @@ def __init__(
"The scheduler has been changed to DPMSolverMultistepScheduler."
)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -768,6 +765,35 @@ def denoising_end(self):
def num_timesteps(self):
return self._num_timesteps
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
def prepare_unet(self, attention_store, PnP: bool = False):
attn_procs = {}
@@ -1401,6 +1427,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
image = self.image_processor.preprocess(
image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
+ height, width = image.shape[-2:]
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(
+ "Image height and width must be a factor of 32. "
+ "Consider down-sampling the input using the `height` and `width` parameters"
+ )
resized = self.image_processor.postprocess(image=image, output_type="pil")
if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
@@ -1439,6 +1471,10 @@ def invert(
crops_coords_top_left: Tuple[int, int] = (0, 0),
num_zero_noise_steps: int = 3,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ resize_mode: Optional[str] = "default",
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
):
r"""
The function to the pipeline for image inversion as described by the [LEDITS++
@@ -1486,6 +1522,8 @@ def invert(
[`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
and respective VAE reconstruction(s).
"""
+ if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+ raise ValueError("height and width must be a factor of 32.")
# Reset attn processor, we do not want to store attn maps during inversion
self.unet.set_attn_processor(AttnProcessor())
@@ -1510,7 +1548,14 @@ def invert(
do_classifier_free_guidance = source_guidance_scale > 1.0
# 1. prepare image
- x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
+ x0, resized = self.encode_image(
+ image,
+ dtype=self.text_encoder_2.dtype,
+ height=height,
+ width=width,
+ resize_mode=resize_mode,
+ crops_coords=crops_coords,
+ )
width = x0.shape[2] * self.vae_scale_factor
height = x0.shape[3] * self.vae_scale_factor
self.size = (height, width)
diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py
new file mode 100644
index 000000000000..199e730d9b4d
--- /dev/null
+++ b/src/diffusers/pipelines/ltx/__init__.py
@@ -0,0 +1,52 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_ltx"] = ["LTXPipeline"]
+ _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
+ _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_ltx import LTXPipeline
+ from .pipeline_ltx_condition import LTXConditionPipeline
+ from .pipeline_ltx_image2video import LTXImageToVideoPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
new file mode 100644
index 000000000000..6f3faed8ff72
--- /dev/null
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -0,0 +1,801 @@
+# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from ...models.autoencoders import AutoencoderKLLTXVideo
+from ...models.transformers import LTXVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import LTXPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import LTXPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> video = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=704,
+ ... height=480,
+ ... num_frames=161,
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=24)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation.
+
+ Reference: https://github.com/Lightricks/LTX-Video
+
+ Args:
+ transformer ([`LTXVideoTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLLTXVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLLTXVideo,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: LTXVideoTransformer3DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
+ )
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ @staticmethod
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+ # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
+ batch_size, num_channels, num_frames, height, width = latents.shape
+ post_patch_num_frames = num_frames // patch_size_t
+ post_patch_height = height // patch_size
+ post_patch_width = width // patch_size
+ latents = latents.reshape(
+ batch_size,
+ -1,
+ post_patch_num_frames,
+ patch_size_t,
+ post_patch_height,
+ patch_size,
+ post_patch_width,
+ patch_size,
+ )
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+ return latents
+
+ @staticmethod
+ def _unpack_latents(
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+ ) -> torch.Tensor:
+ # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+ # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+ # what happens in the `_pack_latents` method.
+ batch_size = latents.size(0)
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ return latents
+
+ @staticmethod
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Normalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) * scaling_factor / latents_std
+ return latents
+
+ @staticmethod
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std / scaling_factor + latents_mean
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size: int = 1,
+ num_channels_latents: int = 128,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ height = height // self.vae_spatial_compression_ratio
+ width = width // self.vae_spatial_compression_ratio
+ num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ frame_rate: int = 25,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ decode_timestep: Union[float, List[float]] = 0.0,
+ decode_noise_scale: Optional[Union[float, List[float]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 128,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `512`):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+ width (`int`, defaults to `704`):
+ The width in pixels of the generated image. This is set to 848 by default for the best results.
+ num_frames (`int`, defaults to `161`):
+ The number of video frames to generate
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, defaults to `3 `):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ decode_timestep (`float`, defaults to `0.0`):
+ The timestep at which generated video is decoded.
+ decode_noise_scale (`float`, defaults to `None`):
+ The interpolation factor between random noise and denoised latents at the decode timestep.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `128 `):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._current_timestep = None
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+ video_sequence_length = latent_num_frames * latent_height * latent_width
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ mu = calculate_shift(
+ video_sequence_length,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare micro-conditions
+ rope_interpolation_scale = (
+ self.vae_temporal_compression_ratio / frame_rate,
+ self.vae_spatial_compression_ratio,
+ self.vae_spatial_compression_ratio,
+ )
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = self._unpack_latents(
+ latents,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+ latents = self._denormalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ if not self.vae.config.timestep_conditioning:
+ timestep = None
+ else:
+ noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ if not isinstance(decode_timestep, list):
+ decode_timestep = [decode_timestep] * batch_size
+ if decode_noise_scale is None:
+ decode_noise_scale = decode_timestep
+ elif not isinstance(decode_noise_scale, list):
+ decode_noise_scale = [decode_noise_scale] * batch_size
+
+ timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
+ :, None, None, None, None
+ ]
+ latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
+
+ video = self.vae.decode(latents, timestep, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LTXPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
new file mode 100644
index 000000000000..ef1fd568397f
--- /dev/null
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -0,0 +1,1181 @@
+# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import PIL.Image
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from ...models.autoencoders import AutoencoderKLLTXVideo
+from ...models.transformers import LTXVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import LTXPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition
+ >>> from diffusers.utils import export_to_video, load_video, load_image
+
+ >>> pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> # Load input image and video
+ >>> video = load_video(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
+ ... )
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
+ ... )
+
+ >>> # Create conditioning objects
+ >>> condition1 = LTXVideoCondition(
+ ... image=image,
+ ... frame_index=0,
+ ... )
+ >>> condition2 = LTXVideoCondition(
+ ... video=video,
+ ... frame_index=80,
+ ... )
+
+ >>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> # Generate video
+ >>> generator = torch.Generator("cuda").manual_seed(0)
+ >>> video = pipe(
+ ... conditions=[condition1, condition2],
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=768,
+ ... height=512,
+ ... num_frames=161,
+ ... num_inference_steps=40,
+ ... generator=generator,
+ ... ).frames[0]
+
+ >>> export_to_video(video, "output.mp4", fps=24)
+ ```
+"""
+
+
+@dataclass
+class LTXVideoCondition:
+ """
+ Defines a single frame-conditioning item for LTX Video - a single frame or a sequence of frames.
+
+ Attributes:
+ image (`PIL.Image.Image`):
+ The image to condition the video on.
+ video (`List[PIL.Image.Image]`):
+ The video to condition the video on.
+ frame_index (`int`):
+ The frame index at which the image or video will conditionally effect the video generation.
+ strength (`float`, defaults to `1.0`):
+ The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied.
+ """
+
+ image: Optional[PIL.Image.Image] = None
+ video: Optional[List[PIL.Image.Image]] = None
+ frame_index: int = 0
+ strength: float = 1.0
+
+
+# from LTX-Video/ltx_video/schedulers/rf.py
+def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None):
+ if linear_steps is None:
+ linear_steps = num_steps // 2
+ if num_steps < 2:
+ return torch.tensor([1.0])
+ linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
+ threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
+ quadratic_steps = num_steps - linear_steps
+ quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
+ linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
+ const = quadratic_coef * (linear_steps**2)
+ quadratic_sigma_schedule = [
+ quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
+ ]
+ sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
+ sigma_schedule = [1.0 - x for x in sigma_schedule]
+ return torch.tensor(sigma_schedule[:-1])
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation.
+
+ Reference: https://github.com/Lightricks/LTX-Video
+
+ Args:
+ transformer ([`LTXVideoTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLLTXVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLLTXVideo,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: LTXVideoTransformer3DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
+ )
+
+ self.default_height = 512
+ self.default_width = 704
+ self.default_frames = 121
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 256,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 256,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ conditions,
+ image,
+ video,
+ frame_index,
+ strength,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ if conditions is not None and (image is not None or video is not None):
+ raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
+
+ if conditions is None and (image is None and video is None):
+ raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
+
+ if conditions is None:
+ if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
+ raise ValueError(
+ "If `conditions` is not provided, `image` and `frame_index` must be of the same length."
+ )
+ elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength):
+ raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.")
+ elif isinstance(video, list) and isinstance(frame_index, list) and len(video) != len(frame_index):
+ raise ValueError(
+ "If `conditions` is not provided, `video` and `frame_index` must be of the same length."
+ )
+ elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
+ raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
+
+ @staticmethod
+ def _prepare_video_ids(
+ batch_size: int,
+ num_frames: int,
+ height: int,
+ width: int,
+ patch_size: int = 1,
+ patch_size_t: int = 1,
+ device: torch.device = None,
+ ) -> torch.Tensor:
+ latent_sample_coords = torch.meshgrid(
+ torch.arange(0, num_frames, patch_size_t, device=device),
+ torch.arange(0, height, patch_size, device=device),
+ torch.arange(0, width, patch_size, device=device),
+ indexing="ij",
+ )
+ latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+ latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+ latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)
+
+ return latent_coords
+
+ @staticmethod
+ def _scale_video_ids(
+ video_ids: torch.Tensor,
+ scale_factor: int = 32,
+ scale_factor_t: int = 8,
+ frame_index: int = 0,
+ device: torch.device = None,
+ ) -> torch.Tensor:
+ scaled_latent_coords = (
+ video_ids
+ * torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None]
+ )
+ scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0)
+ scaled_latent_coords[:, 0] += frame_index
+
+ return scaled_latent_coords
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+ # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
+ batch_size, num_channels, num_frames, height, width = latents.shape
+ post_patch_num_frames = num_frames // patch_size_t
+ post_patch_height = height // patch_size
+ post_patch_width = width // patch_size
+ latents = latents.reshape(
+ batch_size,
+ -1,
+ post_patch_num_frames,
+ patch_size_t,
+ post_patch_height,
+ patch_size,
+ post_patch_width,
+ patch_size,
+ )
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
+ def _unpack_latents(
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+ ) -> torch.Tensor:
+ # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+ # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+ # what happens in the `_pack_latents` method.
+ batch_size = latents.size(0)
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Normalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) * scaling_factor / latents_std
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std / scaling_factor + latents_mean
+ return latents
+
+ def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int):
+ """
+ Trim a conditioning sequence to the allowed number of frames.
+
+ Args:
+ start_frame (int): The target frame number of the first frame in the sequence.
+ sequence_num_frames (int): The number of frames in the sequence.
+ target_num_frames (int): The target number of frames in the generated video.
+ Returns:
+ int: updated sequence length
+ """
+ scale_factor = self.vae_temporal_compression_ratio
+ num_frames = min(sequence_num_frames, target_num_frames - start_frame)
+ # Trim down to a multiple of temporal_scale_factor frames plus 1
+ num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
+ return num_frames
+
+ @staticmethod
+ def add_noise_to_image_conditioning_latents(
+ t: float,
+ init_latents: torch.Tensor,
+ latents: torch.Tensor,
+ noise_scale: float,
+ conditioning_mask: torch.Tensor,
+ generator,
+ eps=1e-6,
+ ):
+ """
+ Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially
+ when conditioned on a single frame.
+ """
+ noise = randn_tensor(
+ latents.shape,
+ generator=generator,
+ device=latents.device,
+ dtype=latents.dtype,
+ )
+ # Add noise only to hard-conditioning latents (conditioning_mask = 1.0)
+ need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1)
+ noised_latents = init_latents + noise_scale * noise * (t**2)
+ latents = torch.where(need_to_noise, noised_latents, latents)
+ return latents
+
+ def prepare_latents(
+ self,
+ conditions: List[torch.Tensor],
+ condition_strength: List[float],
+ condition_frame_index: List[int],
+ batch_size: int = 1,
+ num_channels_latents: int = 128,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ num_prefix_latent_frames: int = 2,
+ generator: Optional[torch.Generator] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> None:
+ num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
+
+ extra_conditioning_latents = []
+ extra_conditioning_video_ids = []
+ extra_conditioning_mask = []
+ extra_conditioning_num_latents = 0
+ for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
+ condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
+ condition_latents = self._normalize_latents(
+ condition_latents, self.vae.latents_mean, self.vae.latents_std
+ ).to(device, dtype=dtype)
+
+ num_data_frames = data.size(2)
+ num_cond_frames = condition_latents.size(2)
+
+ if frame_index == 0:
+ latents[:, :, :num_cond_frames] = torch.lerp(
+ latents[:, :, :num_cond_frames], condition_latents, strength
+ )
+ condition_latent_frames_mask[:, :num_cond_frames] = strength
+
+ else:
+ if num_data_frames > 1:
+ if num_cond_frames < num_prefix_latent_frames:
+ raise ValueError(
+ f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
+ )
+
+ if num_cond_frames > num_prefix_latent_frames:
+ start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
+ end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
+ latents[:, :, start_frame:end_frame] = torch.lerp(
+ latents[:, :, start_frame:end_frame],
+ condition_latents[:, :, num_prefix_latent_frames:],
+ strength,
+ )
+ condition_latent_frames_mask[:, start_frame:end_frame] = strength
+ condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
+
+ noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
+ condition_latents = torch.lerp(noise, condition_latents, strength)
+
+ condition_video_ids = self._prepare_video_ids(
+ batch_size,
+ condition_latents.size(2),
+ latent_height,
+ latent_width,
+ patch_size=self.transformer_spatial_patch_size,
+ patch_size_t=self.transformer_temporal_patch_size,
+ device=device,
+ )
+ condition_video_ids = self._scale_video_ids(
+ condition_video_ids,
+ scale_factor=self.vae_spatial_compression_ratio,
+ scale_factor_t=self.vae_temporal_compression_ratio,
+ frame_index=frame_index,
+ device=device,
+ )
+ condition_latents = self._pack_latents(
+ condition_latents,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+ condition_conditioning_mask = torch.full(
+ condition_latents.shape[:2], strength, device=device, dtype=dtype
+ )
+
+ extra_conditioning_latents.append(condition_latents)
+ extra_conditioning_video_ids.append(condition_video_ids)
+ extra_conditioning_mask.append(condition_conditioning_mask)
+ extra_conditioning_num_latents += condition_latents.size(1)
+
+ video_ids = self._prepare_video_ids(
+ batch_size,
+ num_latent_frames,
+ latent_height,
+ latent_width,
+ patch_size_t=self.transformer_temporal_patch_size,
+ patch_size=self.transformer_spatial_patch_size,
+ device=device,
+ )
+ conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+ video_ids = self._scale_video_ids(
+ video_ids,
+ scale_factor=self.vae_spatial_compression_ratio,
+ scale_factor_t=self.vae_temporal_compression_ratio,
+ frame_index=0,
+ device=device,
+ )
+ latents = self._pack_latents(
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+
+ if len(extra_conditioning_latents) > 0:
+ latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
+ video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
+ conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
+
+ return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ conditions: Union[LTXVideoCondition, List[LTXVideoCondition]] = None,
+ image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
+ video: List[PipelineImageInput] = None,
+ frame_index: Union[int, List[int]] = 0,
+ strength: Union[float, List[float]] = 1.0,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ frame_rate: int = 25,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3,
+ image_cond_noise_scale: float = 0.15,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ decode_timestep: Union[float, List[float]] = 0.0,
+ decode_noise_scale: Optional[Union[float, List[float]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ conditions (`List[LTXVideoCondition], *optional*`):
+ The list of frame-conditioning items for the video generation.If not provided, conditions will be
+ created using `image`, `video`, `frame_index` and `strength`.
+ image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+ The image or images to condition the video generation. If not provided, one has to pass `video` or
+ `conditions`.
+ video (`List[PipelineImageInput]`, *optional*):
+ The video to condition the video generation. If not provided, one has to pass `image` or `conditions`.
+ frame_index (`int` or `List[int]`, *optional*):
+ The frame index or frame indices at which the image or video will conditionally effect the video
+ generation. If not provided, one has to pass `conditions`.
+ strength (`float` or `List[float]`, *optional*):
+ The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `512`):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+ width (`int`, defaults to `704`):
+ The width in pixels of the generated image. This is set to 848 by default for the best results.
+ num_frames (`int`, defaults to `161`):
+ The number of video frames to generate
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, defaults to `3 `):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ decode_timestep (`float`, defaults to `0.0`):
+ The timestep at which generated video is decoded.
+ decode_noise_scale (`float`, defaults to `None`):
+ The interpolation factor between random noise and denoised latents at the decode timestep.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `128 `):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ if latents is not None:
+ raise ValueError("Passing latents is not yet supported.")
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ conditions=conditions,
+ image=image,
+ video=video,
+ frame_index=frame_index,
+ strength=strength,
+ height=height,
+ width=width,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._current_timestep = None
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if conditions is not None:
+ if not isinstance(conditions, list):
+ conditions = [conditions]
+
+ strength = [condition.strength for condition in conditions]
+ frame_index = [condition.frame_index for condition in conditions]
+ image = [condition.image for condition in conditions]
+ video = [condition.video for condition in conditions]
+ else:
+ if not isinstance(image, list):
+ image = [image]
+ num_conditions = 1
+ elif isinstance(image, list):
+ num_conditions = len(image)
+ if not isinstance(video, list):
+ video = [video]
+ num_conditions = 1
+ elif isinstance(video, list):
+ num_conditions = len(video)
+
+ if not isinstance(frame_index, list):
+ frame_index = [frame_index] * num_conditions
+ if not isinstance(strength, list):
+ strength = [strength] * num_conditions
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ vae_dtype = self.vae.dtype
+
+ conditioning_tensors = []
+ for condition_image, condition_video, condition_frame_index, condition_strength in zip(
+ image, video, frame_index, strength
+ ):
+ if condition_image is not None:
+ condition_tensor = (
+ self.video_processor.preprocess(condition_image, height, width)
+ .unsqueeze(2)
+ .to(device, dtype=vae_dtype)
+ )
+ elif condition_video is not None:
+ condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
+ num_frames_input = condition_tensor.size(2)
+ num_frames_output = self.trim_conditioning_sequence(
+ condition_frame_index, num_frames_input, num_frames
+ )
+ condition_tensor = condition_tensor[:, :, :num_frames_output]
+ condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
+ else:
+ raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
+
+ if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
+ raise ValueError(
+ f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
+ f"but got {condition_tensor.size(2)} frames."
+ )
+ conditioning_tensors.append(condition_tensor)
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
+ conditioning_tensors,
+ strength,
+ frame_index,
+ batch_size=batch_size * num_videos_per_prompt,
+ num_channels_latents=num_channels_latents,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ generator=generator,
+ device=device,
+ dtype=torch.float32,
+ )
+
+ video_coords = video_coords.float()
+ video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
+
+ init_latents = latents.clone()
+
+ if self.do_classifier_free_guidance:
+ video_coords = torch.cat([video_coords, video_coords], dim=0)
+
+ # 5. Prepare timesteps
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+ sigmas = linear_quadratic_schedule(num_inference_steps)
+ timesteps = sigmas * 1000
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps=timesteps,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ if image_cond_noise_scale > 0:
+ # Add timestep-dependent noise to the hard-conditioning latents
+ # This helps with motion continuity, especially when conditioned on a single frame
+ latents = self.add_noise_to_image_conditioning_latents(
+ t / 1000.0,
+ init_latents,
+ latents,
+ image_cond_noise_scale,
+ conditioning_mask,
+ generator,
+ )
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ conditioning_mask_model_input = (
+ torch.cat([conditioning_mask, conditioning_mask])
+ if self.do_classifier_free_guidance
+ else conditioning_mask
+ )
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
+ timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ video_coords=video_coords,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ timestep, _ = timestep.chunk(2)
+
+ denoised_latents = self.scheduler.step(
+ -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
+ )[0]
+ tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
+ latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ latents = latents[:, extra_conditioning_num_latents:]
+ latents = self._unpack_latents(
+ latents,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = self._denormalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ if not self.vae.config.timestep_conditioning:
+ timestep = None
+ else:
+ noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ if not isinstance(decode_timestep, list):
+ decode_timestep = [decode_timestep] * batch_size
+ if decode_noise_scale is None:
+ decode_noise_scale = decode_timestep
+ elif not isinstance(decode_noise_scale, list):
+ decode_noise_scale = [decode_noise_scale] * batch_size
+
+ timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
+ :, None, None, None, None
+ ]
+ latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
+
+ video = self.vae.decode(latents, timestep, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LTXPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
new file mode 100644
index 000000000000..1ae67967c6f5
--- /dev/null
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -0,0 +1,899 @@
+# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from ...models.autoencoders import AutoencoderKLLTXVideo
+from ...models.transformers import LTXVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import LTXPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import LTXImageToVideoPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
+ ... )
+ >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene."
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+ >>> video = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... width=704,
+ ... height=480,
+ ... num_frames=161,
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=24)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation.
+
+ Reference: https://github.com/Lightricks/LTX-Video
+
+ Args:
+ transformer ([`LTXVideoTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLLTXVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLLTXVideo,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: LTXVideoTransformer3DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
+ )
+
+ self.default_height = 512
+ self.default_width = 704
+ self.default_frames = 121
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+ # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
+ batch_size, num_channels, num_frames, height, width = latents.shape
+ post_patch_num_frames = num_frames // patch_size_t
+ post_patch_height = height // patch_size
+ post_patch_width = width // patch_size
+ latents = latents.reshape(
+ batch_size,
+ -1,
+ post_patch_num_frames,
+ patch_size_t,
+ post_patch_height,
+ patch_size,
+ post_patch_width,
+ patch_size,
+ )
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
+ def _unpack_latents(
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+ ) -> torch.Tensor:
+ # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+ # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+ # what happens in the `_pack_latents` method.
+ batch_size = latents.size(0)
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Normalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) * scaling_factor / latents_std
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std / scaling_factor + latents_mean
+ return latents
+
+ def prepare_latents(
+ self,
+ image: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ num_channels_latents: int = 128,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ height = height // self.vae_spatial_compression_ratio
+ width = width // self.vae_spatial_compression_ratio
+ num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
+ mask_shape = (batch_size, 1, num_frames, height, width)
+
+ if latents is not None:
+ conditioning_mask = latents.new_zeros(mask_shape)
+ conditioning_mask[:, :, 0] = 1.0
+ conditioning_mask = self._pack_latents(
+ conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ ).squeeze(-1)
+ if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape:
+ raise ValueError(
+ f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}."
+ )
+ return latents.to(device=device, dtype=dtype), conditioning_mask
+
+ if isinstance(generator, list):
+ if len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i])
+ for i in range(batch_size)
+ ]
+ else:
+ init_latents = [
+ retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator) for img in image
+ ]
+
+ init_latents = torch.cat(init_latents, dim=0).to(dtype)
+ init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
+ init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
+ conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
+ conditioning_mask[:, :, 0] = 1.0
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
+
+ conditioning_mask = self._pack_latents(
+ conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ ).squeeze(-1)
+ latents = self._pack_latents(
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+
+ return latents, conditioning_mask
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 704,
+ num_frames: int = 161,
+ frame_rate: int = 25,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 3,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ decode_timestep: Union[float, List[float]] = 0.0,
+ decode_noise_scale: Optional[Union[float, List[float]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 128,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `512`):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+ width (`int`, defaults to `704`):
+ The width in pixels of the generated image. This is set to 848 by default for the best results.
+ num_frames (`int`, defaults to `161`):
+ The number of video frames to generate
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, defaults to `3 `):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ decode_timestep (`float`, defaults to `0.0`):
+ The timestep at which generated video is decoded.
+ decode_noise_scale (`float`, defaults to `None`):
+ The interpolation factor between random noise and denoised latents at the decode timestep.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `128 `):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+ self._current_timestep = None
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Prepare text embeddings
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare latent variables
+ if latents is None:
+ image = self.video_processor.preprocess(image, height=height, width=width)
+ image = image.to(device=device, dtype=prompt_embeds.dtype)
+
+ num_channels_latents = self.transformer.config.in_channels
+ latents, conditioning_mask = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ if self.do_classifier_free_guidance:
+ conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
+
+ # 5. Prepare timesteps
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+ video_sequence_length = latent_num_frames * latent_height * latent_width
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ mu = calculate_shift(
+ video_sequence_length,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare micro-conditions
+ rope_interpolation_scale = (
+ self.vae_temporal_compression_ratio / frame_rate,
+ self.vae_spatial_compression_ratio,
+ self.vae_spatial_compression_ratio,
+ )
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+ timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ timestep, _ = timestep.chunk(2)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ noise_pred = self._unpack_latents(
+ noise_pred,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+ latents = self._unpack_latents(
+ latents,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+
+ noise_pred = noise_pred[:, :, 1:]
+ noise_latents = latents[:, :, 1:]
+ pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
+
+ latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
+ latents = self._pack_latents(
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+ )
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = self._unpack_latents(
+ latents,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+ latents = self._denormalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ if not self.vae.config.timestep_conditioning:
+ timestep = None
+ else:
+ noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ if not isinstance(decode_timestep, list):
+ decode_timestep = [decode_timestep] * batch_size
+ if decode_noise_scale is None:
+ decode_noise_scale = decode_timestep
+ elif not isinstance(decode_noise_scale, list):
+ decode_noise_scale = [decode_noise_scale] * batch_size
+
+ timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
+ :, None, None, None, None
+ ]
+ latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
+
+ video = self.vae.decode(latents, timestep, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LTXPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/ltx/pipeline_output.py b/src/diffusers/pipelines/ltx/pipeline_output.py
new file mode 100644
index 000000000000..36ec3ea884a2
--- /dev/null
+++ b/src/diffusers/pipelines/ltx/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class LTXPipelineOutput(BaseOutput):
+ r"""
+ Output class for LTX pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/lumina/__init__.py b/src/diffusers/pipelines/lumina/__init__.py
index ca1396359721..a19dc7e94641 100644
--- a/src/diffusers/pipelines/lumina/__init__.py
+++ b/src/diffusers/pipelines/lumina/__init__.py
@@ -22,7 +22,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
- _import_structure["pipeline_lumina"] = ["LuminaText2ImgPipeline"]
+ _import_structure["pipeline_lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -32,7 +32,7 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
- from .pipeline_lumina import LuminaText2ImgPipeline
+ from .pipeline_lumina import LuminaPipeline, LuminaText2ImgPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py
index 018f2e8bf1bc..816213f105cb 100644
--- a/src/diffusers/pipelines/lumina/pipeline_lumina.py
+++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -17,11 +17,12 @@
import math
import re
import urllib.parse as ul
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
-from transformers import AutoModel, AutoTokenizer
+from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...models.embeddings import get_2d_rotary_pos_embed_lumina
@@ -29,8 +30,10 @@
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
BACKENDS_MAPPING,
+ deprecate,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -38,8 +41,16 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -50,11 +61,9 @@
Examples:
```py
>>> import torch
- >>> from diffusers import LuminaText2ImgPipeline
+ >>> from diffusers import LuminaPipeline
- >>> pipe = LuminaText2ImgPipeline.from_pretrained(
- ... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
- ... )
+ >>> pipe = LuminaPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16)
>>> # Enable memory optimizations.
>>> pipe.enable_model_cpu_offload()
@@ -124,7 +133,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class LuminaText2ImgPipeline(DiffusionPipeline):
+class LuminaPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Lumina-T2I.
@@ -134,13 +143,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
- text_encoder ([`AutoModel`]):
- Frozen text-encoder. Lumina-T2I uses
- [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
- [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
- tokenizer (`AutoModel`):
- Tokenizer of class
- [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
+ text_encoder ([`GemmaPreTrainedModel`]):
+ Frozen Gemma text-encoder.
+ tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
+ Gemma tokenizer.
transformer ([`Transformer2DModel`]):
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
@@ -165,14 +171,18 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ ]
def __init__(
self,
transformer: LuminaNextDiT2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
- text_encoder: AutoModel,
- tokenizer: AutoTokenizer,
+ text_encoder: GemmaPreTrainedModel,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
):
super().__init__()
@@ -386,9 +396,19 @@ def check_inputs(
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
+ callback_on_step_end_tensor_inputs=None,
):
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -617,7 +637,6 @@ def __call__(
width: Optional[int] = None,
height: Optional[int] = None,
num_inference_steps: int = 30,
- timesteps: List[int] = None,
guidance_scale: float = 4.0,
negative_prompt: Union[str, List[str]] = None,
sigmas: List[float] = None,
@@ -634,6 +653,10 @@ def __call__(
max_sequence_length: int = 256,
scaling_watershed: Optional[float] = 1.0,
proportional_attn: Optional[bool] = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -649,10 +672,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 30):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
@@ -729,7 +748,11 @@ def __call__(
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
+
+ self._guidance_scale = guidance_scale
+
cross_attention_kwargs = {}
# 2. Define call parameters
@@ -776,9 +799,7 @@ def __call__(
prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler, num_inference_steps, device, timesteps, sigmas
- )
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
@@ -793,6 +814,8 @@ def __call__(
latents,
)
+ self._num_timesteps = len(timesteps)
+
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -804,10 +827,11 @@ def __call__(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
+ is_npu = latent_model_input.device.type == "npu"
if isinstance(current_timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
current_timestep = torch.tensor(
[current_timestep],
dtype=dtype,
@@ -881,6 +905,18 @@ def __call__(
progress_bar.update()
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
latents = latents / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
@@ -895,3 +931,23 @@ def __call__(
return (image,)
return ImagePipelineOutput(images=image)
+
+
+class LuminaText2ImgPipeline(LuminaPipeline):
+ def __init__(
+ self,
+ transformer: LuminaNextDiT2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: GemmaPreTrainedModel,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ ):
+ deprecation_message = "`LuminaText2ImgPipeline` has been renamed to `LuminaPipeline` and will be removed in a future version. Please use `LuminaPipeline` instead."
+ deprecate("diffusers.pipelines.lumina.pipeline_lumina.LuminaText2ImgPipeline", "0.34", deprecation_message)
+ super().__init__(
+ transformer=transformer,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
diff --git a/src/diffusers/pipelines/lumina2/__init__.py b/src/diffusers/pipelines/lumina2/__init__.py
new file mode 100644
index 000000000000..b1d6bfeb0d58
--- /dev/null
+++ b/src/diffusers/pipelines/lumina2/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
new file mode 100644
index 000000000000..e0905a2f131f
--- /dev/null
+++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
@@ -0,0 +1,790 @@
+# Copyright 2024 Alpha-VLLM and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import Lumina2LoraLoaderMixin
+from ...models import AutoencoderKL
+from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import Lumina2Pipeline
+
+ >>> pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
+ >>> # Enable memory optimizations.
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Lumina-T2I.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Gemma2PreTrainedModel`]):
+ Frozen Gemma2 text-encoder.
+ tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
+ Gemma tokenizer.
+ transformer ([`Transformer2DModel`]):
+ A text conditioned `Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ transformer: Lumina2Transformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: Gemma2PreTrainedModel,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 8
+ self.default_sample_size = (
+ self.transformer.config.sample_size
+ if hasattr(self, "transformer") and self.transformer is not None
+ else 128
+ )
+ self.default_image_size = self.default_sample_size * self.vae_scale_factor
+ self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts."
+
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ max_sequence_length: int = 256,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids.to(device)
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because Gemma can only handle sequences up to"
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds = self.text_encoder(
+ text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
+ )
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ system_prompt: Optional[str] = None,
+ max_sequence_length: int = 256,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ Lumina-T2I, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string.
+ max_sequence_length (`int`, defaults to `256`):
+ Maximum sequence length to use for the prompt.
+ """
+ if device is None:
+ device = self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if system_prompt is None:
+ system_prompt = self.system_prompt
+ if prompt is not None:
+ prompt = [system_prompt + " " + p for p in prompt]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ batch_size, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
+
+ # Get negative embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
+
+ # Normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ batch_size, seq_len, _ = negative_prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
+ batch_size * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ width: Optional[int] = None,
+ height: Optional[int] = None,
+ num_inference_steps: int = 30,
+ guidance_scale: float = 4.0,
+ negative_prompt: Union[str, List[str]] = None,
+ sigmas: List[float] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ system_prompt: Optional[str] = None,
+ cfg_trunc_ratio: float = 1.0,
+ cfg_normalization: bool = True,
+ max_sequence_length: int = 256,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 30):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ attention_kwargs:
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ system_prompt (`str`, *optional*):
+ The system prompt to use for the image generation.
+ cfg_trunc_ratio (`float`, *optional*, defaults to `1.0`):
+ The ratio of the timestep interval to apply normalization-based guidance scale.
+ cfg_normalization (`bool`, *optional*, defaults to `True`):
+ Whether to apply normalization-based guidance scale.
+ max_sequence_length (`int`, defaults to `256`):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ system_prompt=system_prompt,
+ )
+
+ # 4. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # compute whether apply classifier-free truncation on this timestep
+ do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio
+ # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
+ current_timestep = 1 - t / self.scheduler.config.num_train_timesteps
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latents.shape[0])
+
+ noise_pred_cond = self.transformer(
+ hidden_states=latents,
+ timestep=current_timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+
+ # perform normalization-based guidance scale on a truncated timestep interval
+ if self.do_classifier_free_guidance and not do_classifier_free_truncation:
+ noise_pred_uncond = self.transformer(
+ hidden_states=latents,
+ timestep=current_timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_attention_mask,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ # apply normalization after classifier-free guidance
+ if cfg_normalization:
+ cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True)
+ noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_pred = noise_pred * (cond_norm / noise_norm)
+ else:
+ noise_pred = noise_pred_cond
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ noise_pred = -noise_pred
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ else:
+ image = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
+
+
+class Lumina2Text2ImgPipeline(Lumina2Pipeline):
+ def __init__(
+ self,
+ transformer: Lumina2Transformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: Gemma2PreTrainedModel,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ ):
+ deprecation_message = "`Lumina2Text2ImgPipeline` has been renamed to `Lumina2Pipeline` and will be removed in a future version. Please use `Lumina2Pipeline` instead."
+ deprecate("diffusers.pipelines.lumina2.pipeline_lumina2.Lumina2Text2ImgPipeline", "0.34", deprecation_message)
+ super().__init__(
+ transformer=transformer,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
diff --git a/src/diffusers/pipelines/marigold/__init__.py b/src/diffusers/pipelines/marigold/__init__.py
index b5ae03adfc11..168a8276be4e 100644
--- a/src/diffusers/pipelines/marigold/__init__.py
+++ b/src/diffusers/pipelines/marigold/__init__.py
@@ -23,6 +23,7 @@
else:
_import_structure["marigold_image_processing"] = ["MarigoldImageProcessor"]
_import_structure["pipeline_marigold_depth"] = ["MarigoldDepthOutput", "MarigoldDepthPipeline"]
+ _import_structure["pipeline_marigold_intrinsics"] = ["MarigoldIntrinsicsOutput", "MarigoldIntrinsicsPipeline"]
_import_structure["pipeline_marigold_normals"] = ["MarigoldNormalsOutput", "MarigoldNormalsPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -35,6 +36,7 @@
else:
from .marigold_image_processing import MarigoldImageProcessor
from .pipeline_marigold_depth import MarigoldDepthOutput, MarigoldDepthPipeline
+ from .pipeline_marigold_intrinsics import MarigoldIntrinsicsOutput, MarigoldIntrinsicsPipeline
from .pipeline_marigold_normals import MarigoldNormalsOutput, MarigoldNormalsPipeline
else:
diff --git a/src/diffusers/pipelines/marigold/marigold_image_processing.py b/src/diffusers/pipelines/marigold/marigold_image_processing.py
index 51b9983db6f6..0723014ad37b 100644
--- a/src/diffusers/pipelines/marigold/marigold_image_processing.py
+++ b/src/diffusers/pipelines/marigold/marigold_image_processing.py
@@ -1,4 +1,22 @@
-from typing import List, Optional, Tuple, Union
+# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
+# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+# More information and citation instructions are available on the
+# Marigold project website: https://marigoldcomputervision.github.io
+# --------------------------------------------------------------------------
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL
@@ -379,7 +397,7 @@ def visualize_depth(
val_min: float = 0.0,
val_max: float = 1.0,
color_map: str = "Spectral",
- ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
+ ) -> List[PIL.Image.Image]:
"""
Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`.
@@ -391,7 +409,7 @@ def visualize_depth(
color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel
depth prediction into colored representation.
- Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with depth maps visualization.
+ Returns: `List[PIL.Image.Image]` with depth maps visualization.
"""
if val_max <= val_min:
raise ValueError(f"Invalid values range: [{val_min}, {val_max}].")
@@ -436,7 +454,7 @@ def export_depth_to_16bit_png(
depth: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
val_min: float = 0.0,
val_max: float = 1.0,
- ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
+ ) -> List[PIL.Image.Image]:
def export_depth_to_16bit_png_one(img, idx=None):
prefix = "Depth" + (f"[{idx}]" if idx else "")
if not isinstance(img, np.ndarray) and not torch.is_tensor(img):
@@ -478,7 +496,7 @@ def visualize_normals(
flip_x: bool = False,
flip_y: bool = False,
flip_z: bool = False,
- ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
+ ) -> List[PIL.Image.Image]:
"""
Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`.
@@ -492,7 +510,7 @@ def visualize_normals(
flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference.
Default direction is facing the observer.
- Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with surface normals visualization.
+ Returns: `List[PIL.Image.Image]` with surface normals visualization.
"""
flip_vec = None
if any((flip_x, flip_y, flip_z)):
@@ -528,6 +546,99 @@ def visualize_normals_one(img, idx=None):
else:
raise ValueError(f"Unexpected input type: {type(normals)}")
+ @staticmethod
+ def visualize_intrinsics(
+ prediction: Union[
+ np.ndarray,
+ torch.Tensor,
+ List[np.ndarray],
+ List[torch.Tensor],
+ ],
+ target_properties: Dict[str, Any],
+ color_map: Union[str, Dict[str, str]] = "binary",
+ ) -> List[Dict[str, PIL.Image.Image]]:
+ """
+ Visualizes intrinsic image decomposition, such as predictions of the `MarigoldIntrinsicsPipeline`.
+
+ Args:
+ prediction (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
+ Intrinsic image decomposition.
+ target_properties (`Dict[str, Any]`):
+ Decomposition properties. Expected entries: `target_names: List[str]` and a dictionary with keys
+ `prediction_space: str`, `sub_target_names: List[Union[str, Null]]` (must have 3 entries, null for
+ missing modalities), `up_to_scale: bool`, one for each target and sub-target.
+ color_map (`Union[str, Dict[str, str]]`, *optional*, defaults to `"Spectral"`):
+ Color map used to convert a single-channel predictions into colored representations. When a dictionary
+ is passed, each modality can be colored with its own color map.
+
+ Returns: `List[Dict[str, PIL.Image.Image]]` with intrinsic image decomposition visualization.
+ """
+ if "target_names" not in target_properties:
+ raise ValueError("Missing `target_names` in target_properties")
+ if not isinstance(color_map, str) and not (
+ isinstance(color_map, dict)
+ and all(isinstance(k, str) and isinstance(v, str) for k, v in color_map.items())
+ ):
+ raise ValueError("`color_map` must be a string or a dictionary of strings")
+ n_targets = len(target_properties["target_names"])
+
+ def visualize_targets_one(images, idx=None):
+ # img: [T, 3, H, W]
+ out = {}
+ for target_name, img in zip(target_properties["target_names"], images):
+ img = img.permute(1, 2, 0) # [H, W, 3]
+ prediction_space = target_properties[target_name].get("prediction_space", "srgb")
+ if prediction_space == "stack":
+ sub_target_names = target_properties[target_name]["sub_target_names"]
+ if len(sub_target_names) != 3 or any(
+ not (isinstance(s, str) or s is None) for s in sub_target_names
+ ):
+ raise ValueError(f"Unexpected target sub-names {sub_target_names} in {target_name}")
+ for i, sub_target_name in enumerate(sub_target_names):
+ if sub_target_name is None:
+ continue
+ sub_img = img[:, :, i]
+ sub_prediction_space = target_properties[sub_target_name].get("prediction_space", "srgb")
+ if sub_prediction_space == "linear":
+ sub_up_to_scale = target_properties[sub_target_name].get("up_to_scale", False)
+ if sub_up_to_scale:
+ sub_img = sub_img / max(sub_img.max().item(), 1e-6)
+ sub_img = sub_img ** (1 / 2.2)
+ cmap_name = (
+ color_map if isinstance(color_map, str) else color_map.get(sub_target_name, "binary")
+ )
+ sub_img = MarigoldImageProcessor.colormap(sub_img, cmap=cmap_name, bytes=True)
+ sub_img = PIL.Image.fromarray(sub_img.cpu().numpy())
+ out[sub_target_name] = sub_img
+ elif prediction_space == "linear":
+ up_to_scale = target_properties[target_name].get("up_to_scale", False)
+ if up_to_scale:
+ img = img / max(img.max().item(), 1e-6)
+ img = img ** (1 / 2.2)
+ elif prediction_space == "srgb":
+ pass
+ img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy()
+ img = PIL.Image.fromarray(img)
+ out[target_name] = img
+ return out
+
+ if prediction is None or isinstance(prediction, list) and any(o is None for o in prediction):
+ raise ValueError("Input prediction is `None`")
+ if isinstance(prediction, (np.ndarray, torch.Tensor)):
+ prediction = MarigoldImageProcessor.expand_tensor_or_array(prediction)
+ if isinstance(prediction, np.ndarray):
+ prediction = MarigoldImageProcessor.numpy_to_pt(prediction) # [N*T,3,H,W]
+ if not (prediction.ndim == 4 and prediction.shape[1] == 3 and prediction.shape[0] % n_targets == 0):
+ raise ValueError(f"Unexpected input shape={prediction.shape}, expecting [N*T,3,H,W].")
+ N_T, _, H, W = prediction.shape
+ N = N_T // n_targets
+ prediction = prediction.reshape(N, n_targets, 3, H, W)
+ return [visualize_targets_one(img, idx) for idx, img in enumerate(prediction)]
+ elif isinstance(prediction, list):
+ return [visualize_targets_one(img, idx) for idx, img in enumerate(prediction)]
+ else:
+ raise ValueError(f"Unexpected input type: {type(prediction)}")
+
@staticmethod
def visualize_uncertainty(
uncertainty: Union[
@@ -537,9 +648,10 @@ def visualize_uncertainty(
List[torch.Tensor],
],
saturation_percentile=95,
- ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
+ ) -> List[PIL.Image.Image]:
"""
- Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`.
+ Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline`, `MarigoldNormalsPipeline`, or
+ `MarigoldIntrinsicsPipeline`.
Args:
uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
@@ -547,14 +659,15 @@ def visualize_uncertainty(
saturation_percentile (`int`, *optional*, defaults to `95`):
Specifies the percentile uncertainty value visualized with maximum intensity.
- Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with uncertainty visualization.
+ Returns: `List[PIL.Image.Image]` with uncertainty visualization.
"""
def visualize_uncertainty_one(img, idx=None):
prefix = "Uncertainty" + (f"[{idx}]" if idx else "")
if img.min() < 0:
- raise ValueError(f"{prefix}: unexected data range, min={img.min()}.")
- img = img.squeeze(0).cpu().numpy()
+ raise ValueError(f"{prefix}: unexpected data range, min={img.min()}.")
+ img = img.permute(1, 2, 0) # [H,W,C]
+ img = img.squeeze(2).cpu().numpy() # [H,W] or [H,W,3]
saturation_value = np.percentile(img, saturation_percentile)
img = np.clip(img * 255 / saturation_value, 0, 255)
img = img.astype(np.uint8)
@@ -566,9 +679,9 @@ def visualize_uncertainty_one(img, idx=None):
if isinstance(uncertainty, (np.ndarray, torch.Tensor)):
uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty)
if isinstance(uncertainty, np.ndarray):
- uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty) # [N,1,H,W]
- if not (uncertainty.ndim == 4 and uncertainty.shape[1] == 1):
- raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,1,H,W].")
+ uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty) # [N,C,H,W]
+ if not (uncertainty.ndim == 4 and uncertainty.shape[1] in (1, 3)):
+ raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,C,H,W] with C in (1,3).")
return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
elif isinstance(uncertainty, list):
return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
index a602ba611ea5..da991aefbd4a 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
@@ -1,5 +1,5 @@
-# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
+# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# limitations under the License.
# --------------------------------------------------------------------------
# More information and citation instructions are available on the
-# Marigold project website: https://marigoldmonodepth.github.io
+# Marigold project website: https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
from dataclasses import dataclass
from functools import partial
@@ -37,6 +37,7 @@
)
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -46,6 +47,13 @@
from .marigold_image_processing import MarigoldImageProcessor
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -56,7 +64,7 @@
>>> import torch
>>> pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
-... "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
+... "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
... ).to("cuda")
>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
@@ -78,11 +86,12 @@ class MarigoldDepthOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
- Predicted depth maps with values in the range [0, 1]. The shape is always $numimages \times 1 \times height
- \times width$, regardless of whether the images were passed as a 4D array or a list.
+ Predicted depth maps with values in the range [0, 1]. The shape is $numimages \times 1 \times height \times
+ width$ for `torch.Tensor` or $numimages \times height \times width \times 1$ for `np.ndarray`.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
- \times 1 \times height \times width$.
+ \times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
+ for `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
@@ -174,7 +183,7 @@ def __init__(
default_processing_resolution=default_processing_resolution,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.scale_invariant = scale_invariant
self.shift_invariant = shift_invariant
@@ -200,6 +209,11 @@ def check_inputs(
output_type: str,
output_uncertainty: bool,
) -> int:
+ actual_vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ if actual_vae_scale_factor != self.vae_scale_factor:
+ raise ValueError(
+ f"`vae_scale_factor` computed at initialization ({self.vae_scale_factor}) differs from the actual one ({actual_vae_scale_factor})."
+ )
if num_inference_steps is None:
raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
if num_inference_steps < 1:
@@ -312,6 +326,7 @@ def check_inputs(
return num_images
+ @torch.compiler.disable
def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
@@ -362,11 +377,9 @@ def __call__(
same width and height.
num_inference_steps (`int`, *optional*, defaults to `None`):
Number of denoising diffusion steps during inference. The default value `None` results in automatic
- selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
- for Marigold-LCM models.
+ selection.
ensemble_size (`int`, defaults to `1`):
- Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
- faster inference.
+ Number of ensemble predictions. Higher values result in measurable improvements and visual degradation.
processing_resolution (`int`, *optional*, defaults to `None`):
Effective processing resolution. When set to `0`, matches the larger input image dimension. This
produces crisper predictions, but may also lead to the overall loss of global context. The default
@@ -478,9 +491,7 @@ def __call__(
# `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
# into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
# reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
- # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
- # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
- # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
+ # code. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
# dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
# Model invocation: self.vae.encoder.
image_latent, pred_latent = self.prepare_latents(
@@ -517,6 +528,9 @@ def __call__(
noise, t, batch_pred_latent, generator=generator
).prev_sample # [B,4,h,w]
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
pred_latents.append(batch_pred_latent)
pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w]
@@ -722,6 +736,7 @@ def init_param(depth: torch.Tensor):
param = init_s.cpu().numpy()
else:
raise ValueError("Unrecognized alignment.")
+ param = param.astype(np.float64)
return param
@@ -764,7 +779,7 @@ def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
if regularizer_strength > 0:
prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
- err_near = (0.0 - prediction.min()).abs().item()
+ err_near = prediction.min().abs().item()
err_far = (1.0 - prediction.max()).abs().item()
cost += (err_near + err_far) * regularizer_strength
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py b/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py
new file mode 100644
index 000000000000..c809de18f469
--- /dev/null
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py
@@ -0,0 +1,721 @@
+# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
+# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+# More information and citation instructions are available on the
+# Marigold project website: https://marigoldcomputervision.github.io
+# --------------------------------------------------------------------------
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import PipelineImageInput
+from ...models import (
+ AutoencoderKL,
+ UNet2DConditionModel,
+)
+from ...schedulers import (
+ DDIMScheduler,
+ LCMScheduler,
+)
+from ...utils import (
+ BaseOutput,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .marigold_image_processing import MarigoldImageProcessor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+Examples:
+```py
+>>> import diffusers
+>>> import torch
+
+>>> pipe = diffusers.MarigoldIntrinsicsPipeline.from_pretrained(
+... "prs-eth/marigold-iid-appearance-v1-1", variant="fp16", torch_dtype=torch.float16
+... ).to("cuda")
+
+>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+>>> intrinsics = pipe(image)
+
+>>> vis = pipe.image_processor.visualize_intrinsics(intrinsics.prediction, pipe.target_properties)
+>>> vis[0]["albedo"].save("einstein_albedo.png")
+>>> vis[0]["roughness"].save("einstein_roughness.png")
+>>> vis[0]["metallicity"].save("einstein_metallicity.png")
+```
+```py
+>>> import diffusers
+>>> import torch
+
+>>> pipe = diffusers.MarigoldIntrinsicsPipeline.from_pretrained(
+... "prs-eth/marigold-iid-lighting-v1-1", variant="fp16", torch_dtype=torch.float16
+... ).to("cuda")
+
+>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+>>> intrinsics = pipe(image)
+
+>>> vis = pipe.image_processor.visualize_intrinsics(intrinsics.prediction, pipe.target_properties)
+>>> vis[0]["albedo"].save("einstein_albedo.png")
+>>> vis[0]["shading"].save("einstein_shading.png")
+>>> vis[0]["residual"].save("einstein_residual.png")
+```
+"""
+
+
+@dataclass
+class MarigoldIntrinsicsOutput(BaseOutput):
+ """
+ Output class for Marigold Intrinsic Image Decomposition pipeline.
+
+ Args:
+ prediction (`np.ndarray`, `torch.Tensor`):
+ Predicted image intrinsics with values in the range [0, 1]. The shape is $(numimages * numtargets) \times 3
+ \times height \times width$ for `torch.Tensor` or $(numimages * numtargets) \times height \times width
+ \times 3$ for `np.ndarray`, where `numtargets` corresponds to the number of predicted target modalities of
+ the intrinsic image decomposition.
+ uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $(numimages *
+ numtargets) \times 3 \times height \times width$ for `torch.Tensor` or $(numimages * numtargets) \times
+ height \times width \times 3$ for `np.ndarray`.
+ latent (`None`, `torch.Tensor`):
+ Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
+ The shape is $(numimages * numensemble) \times (numtargets * 4) \times latentheight \times latentwidth$.
+ """
+
+ prediction: Union[np.ndarray, torch.Tensor]
+ uncertainty: Union[None, np.ndarray, torch.Tensor]
+ latent: Union[None, torch.Tensor]
+
+
+class MarigoldIntrinsicsPipeline(DiffusionPipeline):
+ """
+ Pipeline for Intrinsic Image Decomposition (IID) using the Marigold method:
+ https://marigoldcomputervision.github.io.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ unet (`UNet2DConditionModel`):
+ Conditional U-Net to denoise the targets latent, conditioned on image latent.
+ vae (`AutoencoderKL`):
+ Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent
+ representations.
+ scheduler (`DDIMScheduler` or `LCMScheduler`):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ text_encoder (`CLIPTextModel`):
+ Text-encoder, for empty text embedding.
+ tokenizer (`CLIPTokenizer`):
+ CLIP tokenizer.
+ prediction_type (`str`, *optional*):
+ Type of predictions made by the model.
+ target_properties (`Dict[str, Any]`, *optional*):
+ Properties of the predicted modalities, such as `target_names`, a `List[str]` used to define the number,
+ order and names of the predicted modalities, and any other metadata that may be required to interpret the
+ predictions.
+ default_denoising_steps (`int`, *optional*):
+ The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
+ quality with the given model. This value must be set in the model config. When the pipeline is called
+ without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
+ reasonable results with various model flavors compatible with the pipeline, such as those relying on very
+ short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
+ default_processing_resolution (`int`, *optional*):
+ The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
+ the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
+ default value is used. This is required to ensure reasonable results with various model flavors trained
+ with varying optimal processing resolution values.
+ """
+
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ supported_prediction_types = ("intrinsics",)
+
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ vae: AutoencoderKL,
+ scheduler: Union[DDIMScheduler, LCMScheduler],
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ prediction_type: Optional[str] = None,
+ target_properties: Optional[Dict[str, Any]] = None,
+ default_denoising_steps: Optional[int] = None,
+ default_processing_resolution: Optional[int] = None,
+ ):
+ super().__init__()
+
+ if prediction_type not in self.supported_prediction_types:
+ logger.warning(
+ f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: "
+ f"{self.supported_prediction_types}."
+ )
+
+ self.register_modules(
+ unet=unet,
+ vae=vae,
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+ self.register_to_config(
+ prediction_type=prediction_type,
+ target_properties=target_properties,
+ default_denoising_steps=default_denoising_steps,
+ default_processing_resolution=default_processing_resolution,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+
+ self.target_properties = target_properties
+ self.default_denoising_steps = default_denoising_steps
+ self.default_processing_resolution = default_processing_resolution
+
+ self.empty_text_embedding = None
+
+ self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ @property
+ def n_targets(self):
+ return self.unet.config.out_channels // self.vae.config.latent_channels
+
+ def check_inputs(
+ self,
+ image: PipelineImageInput,
+ num_inference_steps: int,
+ ensemble_size: int,
+ processing_resolution: int,
+ resample_method_input: str,
+ resample_method_output: str,
+ batch_size: int,
+ ensembling_kwargs: Optional[Dict[str, Any]],
+ latents: Optional[torch.Tensor],
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
+ output_type: str,
+ output_uncertainty: bool,
+ ) -> int:
+ actual_vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ if actual_vae_scale_factor != self.vae_scale_factor:
+ raise ValueError(
+ f"`vae_scale_factor` computed at initialization ({self.vae_scale_factor}) differs from the actual one ({actual_vae_scale_factor})."
+ )
+ if num_inference_steps is None:
+ raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
+ if num_inference_steps < 1:
+ raise ValueError("`num_inference_steps` must be positive.")
+ if ensemble_size < 1:
+ raise ValueError("`ensemble_size` must be positive.")
+ if ensemble_size == 2:
+ logger.warning(
+ "`ensemble_size` == 2 results are similar to no ensembling (1); "
+ "consider increasing the value to at least 3."
+ )
+ if ensemble_size == 1 and output_uncertainty:
+ raise ValueError(
+ "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` "
+ "greater than 1."
+ )
+ if processing_resolution is None:
+ raise ValueError(
+ "`processing_resolution` is not specified and could not be resolved from the model config."
+ )
+ if processing_resolution < 0:
+ raise ValueError(
+ "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
+ "downsampled processing."
+ )
+ if processing_resolution % self.vae_scale_factor != 0:
+ raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
+ if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
+ raise ValueError(
+ "`resample_method_input` takes string values compatible with PIL library: "
+ "nearest, nearest-exact, bilinear, bicubic, area."
+ )
+ if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
+ raise ValueError(
+ "`resample_method_output` takes string values compatible with PIL library: "
+ "nearest, nearest-exact, bilinear, bicubic, area."
+ )
+ if batch_size < 1:
+ raise ValueError("`batch_size` must be positive.")
+ if output_type not in ["pt", "np"]:
+ raise ValueError("`output_type` must be one of `pt` or `np`.")
+ if latents is not None and generator is not None:
+ raise ValueError("`latents` and `generator` cannot be used together.")
+ if ensembling_kwargs is not None:
+ if not isinstance(ensembling_kwargs, dict):
+ raise ValueError("`ensembling_kwargs` must be a dictionary.")
+ if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("median", "mean"):
+ raise ValueError("`ensembling_kwargs['reduction']` can be either `'median'` or `'mean'`.")
+
+ # image checks
+ num_images = 0
+ W, H = None, None
+ if not isinstance(image, list):
+ image = [image]
+ for i, img in enumerate(image):
+ if isinstance(img, np.ndarray) or torch.is_tensor(img):
+ if img.ndim not in (2, 3, 4):
+ raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
+ H_i, W_i = img.shape[-2:]
+ N_i = 1
+ if img.ndim == 4:
+ N_i = img.shape[0]
+ elif isinstance(img, Image.Image):
+ W_i, H_i = img.size
+ N_i = 1
+ else:
+ raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
+ if W is None:
+ W, H = W_i, H_i
+ elif (W, H) != (W_i, H_i):
+ raise ValueError(
+ f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
+ )
+ num_images += N_i
+
+ # latents checks
+ if latents is not None:
+ if not torch.is_tensor(latents):
+ raise ValueError("`latents` must be a torch.Tensor.")
+ if latents.dim() != 4:
+ raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.")
+
+ if processing_resolution > 0:
+ max_orig = max(H, W)
+ new_H = H * processing_resolution // max_orig
+ new_W = W * processing_resolution // max_orig
+ if new_H == 0 or new_W == 0:
+ raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
+ W, H = new_W, new_H
+ w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
+ h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
+ shape_expected = (num_images * ensemble_size, self.unet.config.out_channels, h, w)
+
+ if latents.shape != shape_expected:
+ raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")
+
+ # generator checks
+ if generator is not None:
+ if isinstance(generator, list):
+ if len(generator) != num_images * ensemble_size:
+ raise ValueError(
+ "The number of generators must match the total number of ensemble members for all input images."
+ )
+ if not all(g.device.type == generator[0].device.type for g in generator):
+ raise ValueError("`generator` device placement is not consistent in the list.")
+ elif not isinstance(generator, torch.Generator):
+ raise ValueError(f"Unsupported generator type: {type(generator)}.")
+
+ return num_images
+
+ @torch.compiler.disable
+ def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ elif not isinstance(self._progress_bar_config, dict):
+ raise ValueError(
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ )
+
+ progress_bar_config = dict(**self._progress_bar_config)
+ progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
+ progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
+ if iterable is not None:
+ return tqdm(iterable, **progress_bar_config)
+ elif total is not None:
+ return tqdm(total=total, **progress_bar_config)
+ else:
+ raise ValueError("Either `total` or `iterable` has to be defined.")
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ num_inference_steps: Optional[int] = None,
+ ensemble_size: int = 1,
+ processing_resolution: Optional[int] = None,
+ match_input_resolution: bool = True,
+ resample_method_input: str = "bilinear",
+ resample_method_output: str = "bilinear",
+ batch_size: int = 1,
+ ensembling_kwargs: Optional[Dict[str, Any]] = None,
+ latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: str = "np",
+ output_uncertainty: bool = False,
+ output_latent: bool = False,
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline.
+
+ Args:
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`),
+ `List[torch.Tensor]`: An input image or images used as an input for the intrinsic decomposition task.
+ For arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is
+ possible by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
+ three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
+ same width and height.
+ num_inference_steps (`int`, *optional*, defaults to `None`):
+ Number of denoising diffusion steps during inference. The default value `None` results in automatic
+ selection.
+ ensemble_size (`int`, defaults to `1`):
+ Number of ensemble predictions. Higher values result in measurable improvements and visual degradation.
+ processing_resolution (`int`, *optional*, defaults to `None`):
+ Effective processing resolution. When set to `0`, matches the larger input image dimension. This
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
+ value `None` resolves to the optimal value from the model config.
+ match_input_resolution (`bool`, *optional*, defaults to `True`):
+ When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
+ side of the output will equal to `processing_resolution`.
+ resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
+ Resampling method used to resize input images to `processing_resolution`. The accepted values are:
+ `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
+ resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
+ Resampling method used to resize output predictions to match the input resolution. The accepted values
+ are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
+ batch_size (`int`, *optional*, defaults to `1`):
+ Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
+ ensembling_kwargs (`dict`, *optional*, defaults to `None`)
+ Extra dictionary with arguments for precise ensembling control. The following options are available:
+ - reduction (`str`, *optional*, defaults to `"median"`): Defines the ensembling function applied in
+ every pixel location, can be either `"median"` or `"mean"`.
+ latents (`torch.Tensor`, *optional*, defaults to `None`):
+ Latent noise tensors to replace the random initialization. These can be taken from the previous
+ function call's output.
+ generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
+ Random number generator object to ensure reproducibility.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
+ values are: `"np"` (numpy array) or `"pt"` (torch tensor).
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
+ When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
+ the `ensemble_size` argument is set to a value above 2.
+ output_latent (`bool`, *optional*, defaults to `False`):
+ When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
+ within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
+ `latents` argument.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.marigold.MarigoldIntrinsicsOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.marigold.MarigoldIntrinsicsOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.marigold.MarigoldIntrinsicsOutput`] is returned, otherwise a
+ `tuple` is returned where the first element is the prediction, the second element is the uncertainty
+ (or `None`), and the third is the latent (or `None`).
+ """
+
+ # 0. Resolving variables.
+ device = self._execution_device
+ dtype = self.dtype
+
+ # Model-specific optimal default values leading to fast and reasonable results.
+ if num_inference_steps is None:
+ num_inference_steps = self.default_denoising_steps
+ if processing_resolution is None:
+ processing_resolution = self.default_processing_resolution
+
+ # 1. Check inputs.
+ num_images = self.check_inputs(
+ image,
+ num_inference_steps,
+ ensemble_size,
+ processing_resolution,
+ resample_method_input,
+ resample_method_output,
+ batch_size,
+ ensembling_kwargs,
+ latents,
+ generator,
+ output_type,
+ output_uncertainty,
+ )
+
+ # 2. Prepare empty text conditioning.
+ # Model invocation: self.tokenizer, self.text_encoder.
+ if self.empty_text_embedding is None:
+ prompt = ""
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="do_not_pad",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids.to(device)
+ self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
+
+ # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
+ # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
+ # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
+ # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
+ # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
+ # operation and leads to the most reasonable results. Using the native image resolution or any other processing
+ # resolution can lead to loss of either fine details or global context in the output predictions.
+ image, padding, original_resolution = self.image_processor.preprocess(
+ image, processing_resolution, resample_method_input, device, dtype
+ ) # [N,3,PPH,PPW]
+
+ # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
+ # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
+ # Latents of each such predictions across all input images and all ensemble members are represented in the
+ # `pred_latent` variable. The variable `image_latent` contains each input image encoded into latent space and
+ # replicated `E` times. The variable `pred_latent` contains latents initialization, where the latent space is
+ # replicated `T` times relative to the single latent space of `image_latent`, where `T` is the number of the
+ # predicted targets. The latents can be either generated (see `generator` to ensure reproducibility), or passed
+ # explicitly via the `latents` argument. The latter can be set outside the pipeline code. This behavior can be
+ # achieved by setting the `output_latent` argument to `True`. The latent space dimensions are `(h, w)`. Encoding
+ # into latent space happens in batches of size `batch_size`.
+ # Model invocation: self.vae.encoder.
+ image_latent, pred_latent = self.prepare_latents(
+ image, latents, generator, ensemble_size, batch_size
+ ) # [N*E,4,h,w], [N*E,T*4,h,w]
+
+ del image
+
+ batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat(
+ batch_size, 1, 1
+ ) # [B,1024,2]
+
+ # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`.
+ # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and
+ # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by
+ # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded
+ # model.
+ # Model invocation: self.unet.
+ pred_latents = []
+
+ for i in self.progress_bar(
+ range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..."
+ ):
+ batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w]
+ batch_pred_latent = pred_latent[i : i + batch_size] # [B,T*4,h,w]
+ effective_batch_size = batch_image_latent.shape[0]
+ text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024]
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."):
+ batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,(1+T)*4,h,w]
+ noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,T*4,h,w]
+ batch_pred_latent = self.scheduler.step(
+ noise, t, batch_pred_latent, generator=generator
+ ).prev_sample # [B,T*4,h,w]
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ pred_latents.append(batch_pred_latent)
+
+ pred_latent = torch.cat(pred_latents, dim=0) # [N*E,T*4,h,w]
+
+ del (
+ pred_latents,
+ image_latent,
+ batch_empty_text_embedding,
+ batch_image_latent,
+ batch_pred_latent,
+ text,
+ batch_latent,
+ noise,
+ )
+
+ # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`,
+ # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`.
+ # Model invocation: self.vae.decoder.
+ pred_latent_for_decoding = pred_latent.reshape(
+ num_images * ensemble_size * self.n_targets, self.vae.config.latent_channels, *pred_latent.shape[2:]
+ ) # [N*E*T,4,PPH,PPW]
+ prediction = torch.cat(
+ [
+ self.decode_prediction(pred_latent_for_decoding[i : i + batch_size])
+ for i in range(0, pred_latent_for_decoding.shape[0], batch_size)
+ ],
+ dim=0,
+ ) # [N*E*T,3,PPH,PPW]
+
+ del pred_latent_for_decoding
+ if not output_latent:
+ pred_latent = None
+
+ # 7. Remove padding. The output shape is (PH, PW).
+ prediction = self.image_processor.unpad_image(prediction, padding) # [N*E*T,3,PH,PW]
+
+ # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N*T`
+ # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape
+ # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for
+ # each group independently, it stacks them respectively into batches of `N*T` almost final predictions and
+ # uncertainty maps.
+ uncertainty = None
+ if ensemble_size > 1:
+ prediction = prediction.reshape(
+ num_images, ensemble_size, self.n_targets, *prediction.shape[1:]
+ ) # [N,E,T,3,PH,PW]
+ prediction = [
+ self.ensemble_intrinsics(prediction[i], output_uncertainty, **(ensembling_kwargs or {}))
+ for i in range(num_images)
+ ] # [ [[T,3,PH,PW], [T,3,PH,PW]], ... ]
+ prediction, uncertainty = zip(*prediction) # [[T,3,PH,PW], ... ], [[T,3,PH,PW], ... ]
+ prediction = torch.cat(prediction, dim=0) # [N*T,3,PH,PW]
+ if output_uncertainty:
+ uncertainty = torch.cat(uncertainty, dim=0) # [N*T,3,PH,PW]
+ else:
+ uncertainty = None
+
+ # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the
+ # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled.
+ # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by
+ # setting the `resample_method_output` parameter (e.g., to `"nearest"`).
+ if match_input_resolution:
+ prediction = self.image_processor.resize_antialias(
+ prediction, original_resolution, resample_method_output, is_aa=False
+ ) # [N*T,3,H,W]
+ if uncertainty is not None and output_uncertainty:
+ uncertainty = self.image_processor.resize_antialias(
+ uncertainty, original_resolution, resample_method_output, is_aa=False
+ ) # [N*T,1,H,W]
+
+ # 10. Prepare the final outputs.
+ if output_type == "np":
+ prediction = self.image_processor.pt_to_numpy(prediction) # [N*T,H,W,3]
+ if uncertainty is not None and output_uncertainty:
+ uncertainty = self.image_processor.pt_to_numpy(uncertainty) # [N*T,H,W,3]
+
+ # 11. Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (prediction, uncertainty, pred_latent)
+
+ return MarigoldIntrinsicsOutput(
+ prediction=prediction,
+ uncertainty=uncertainty,
+ latent=pred_latent,
+ )
+
+ def prepare_latents(
+ self,
+ image: torch.Tensor,
+ latents: Optional[torch.Tensor],
+ generator: Optional[torch.Generator],
+ ensemble_size: int,
+ batch_size: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ def retrieve_latents(encoder_output):
+ if hasattr(encoder_output, "latent_dist"):
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+ image_latent = torch.cat(
+ [
+ retrieve_latents(self.vae.encode(image[i : i + batch_size]))
+ for i in range(0, image.shape[0], batch_size)
+ ],
+ dim=0,
+ ) # [N,4,h,w]
+ image_latent = image_latent * self.vae.config.scaling_factor
+ image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
+ N_E, C, H, W = image_latent.shape
+
+ pred_latent = latents
+ if pred_latent is None:
+ pred_latent = randn_tensor(
+ (N_E, self.n_targets * C, H, W),
+ generator=generator,
+ device=image_latent.device,
+ dtype=image_latent.dtype,
+ ) # [N*E,T*4,h,w]
+
+ return image_latent, pred_latent
+
+ def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
+ if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
+ raise ValueError(
+ f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
+ )
+
+ prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
+
+ prediction = torch.clip(prediction, -1.0, 1.0) # [B,3,H,W]
+ prediction = (prediction + 1.0) / 2.0
+
+ return prediction # [B,3,H,W]
+
+ @staticmethod
+ def ensemble_intrinsics(
+ targets: torch.Tensor,
+ output_uncertainty: bool = False,
+ reduction: str = "median",
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Ensembles the intrinsic decomposition represented by the `targets` tensor with expected shape `(B, T, 3, H,
+ W)`, where B is the number of ensemble members for a given prediction of size `(H x W)`, and T is the number of
+ predicted targets.
+
+ Args:
+ targets (`torch.Tensor`):
+ Input ensemble of intrinsic image decomposition maps.
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
+ Whether to output uncertainty map.
+ reduction (`str`, *optional*, defaults to `"mean"`):
+ Reduction method used to ensemble aligned predictions. The accepted values are: `"median"` and
+ `"mean"`.
+
+ Returns:
+ A tensor of aligned and ensembled intrinsic decomposition maps with shape `(T, 3, H, W)` and optionally a
+ tensor of uncertainties of shape `(T, 3, H, W)`.
+ """
+ if targets.dim() != 5 or targets.shape[2] != 3:
+ raise ValueError(f"Expecting 4D tensor of shape [B,T,3,H,W]; got {targets.shape}.")
+ if reduction not in ("median", "mean"):
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
+
+ B, T, _, H, W = targets.shape
+ uncertainty = None
+ if reduction == "mean":
+ prediction = torch.mean(targets, dim=0) # [T,3,H,W]
+ if output_uncertainty:
+ uncertainty = torch.std(targets, dim=0) # [T,3,H,W]
+ elif reduction == "median":
+ prediction = torch.median(targets, dim=0, keepdim=True).values # [1,T,3,H,W]
+ if output_uncertainty:
+ uncertainty = torch.abs(targets - prediction) # [B,T,3,H,W]
+ uncertainty = torch.median(uncertainty, dim=0).values # [T,3,H,W]
+ prediction = prediction.squeeze(0) # [T,3,H,W]
+ else:
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
+ return prediction, uncertainty
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
index aa9ad36ffc35..192ed590a489 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
@@ -1,5 +1,5 @@
-# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
+# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# limitations under the License.
# --------------------------------------------------------------------------
# More information and citation instructions are available on the
-# Marigold project website: https://marigoldmonodepth.github.io
+# Marigold project website: https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -36,6 +36,7 @@
)
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -44,6 +45,13 @@
from .marigold_image_processing import MarigoldImageProcessor
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -54,7 +62,7 @@
>>> import torch
>>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
-... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
+... "prs-eth/marigold-normals-v1-1", variant="fp16", torch_dtype=torch.float16
... ).to("cuda")
>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
@@ -73,11 +81,12 @@ class MarigoldNormalsOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
- Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height
- \times width$, regardless of whether the images were passed as a 4D array or a list.
+ Predicted normals with values in the range [-1, 1]. The shape is $numimages \times 3 \times height \times
+ width$ for `torch.Tensor` or $numimages \times height \times width \times 3$ for `np.ndarray`.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
- \times 1 \times height \times width$.
+ \times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
+ for `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
@@ -156,12 +165,13 @@ def __init__(
tokenizer=tokenizer,
)
self.register_to_config(
+ prediction_type=prediction_type,
use_full_z_range=use_full_z_range,
default_denoising_steps=default_denoising_steps,
default_processing_resolution=default_processing_resolution,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.use_full_z_range = use_full_z_range
self.default_denoising_steps = default_denoising_steps
@@ -186,6 +196,11 @@ def check_inputs(
output_type: str,
output_uncertainty: bool,
) -> int:
+ actual_vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ if actual_vae_scale_factor != self.vae_scale_factor:
+ raise ValueError(
+ f"`vae_scale_factor` computed at initialization ({self.vae_scale_factor}) differs from the actual one ({actual_vae_scale_factor})."
+ )
if num_inference_steps is None:
raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
if num_inference_steps < 1:
@@ -296,6 +311,7 @@ def check_inputs(
return num_images
+ @torch.compiler.disable
def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
@@ -346,11 +362,9 @@ def __call__(
same width and height.
num_inference_steps (`int`, *optional*, defaults to `None`):
Number of denoising diffusion steps during inference. The default value `None` results in automatic
- selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
- for Marigold-LCM models.
+ selection.
ensemble_size (`int`, defaults to `1`):
- Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
- faster inference.
+ Number of ensemble predictions. Higher values result in measurable improvements and visual degradation.
processing_resolution (`int`, *optional*, defaults to `None`):
Effective processing resolution. When set to `0`, matches the larger input image dimension. This
produces crisper predictions, but may also lead to the overall loss of global context. The default
@@ -386,7 +400,7 @@ def __call__(
within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
`latents` argument.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple.
+ Whether or not to return a [`~pipelines.marigold.MarigoldNormalsOutput`] instead of a plain tuple.
Examples:
@@ -454,9 +468,7 @@ def __call__(
# `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
# into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
# reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
- # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
- # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
- # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
+ # code. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
# dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
# Model invocation: self.vae.encoder.
image_latent, pred_latent = self.prepare_latents(
@@ -493,6 +505,9 @@ def __call__(
noise, t, batch_pred_latent, generator=generator
).prev_sample # [B,4,h,w]
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
pred_latents.append(batch_pred_latent)
pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w]
diff --git a/src/diffusers/pipelines/mochi/__init__.py b/src/diffusers/pipelines/mochi/__init__.py
new file mode 100644
index 000000000000..a8fd4da9fd36
--- /dev/null
+++ b/src/diffusers/pipelines/mochi/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_mochi"] = ["MochiPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_mochi import MochiPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
new file mode 100644
index 000000000000..d1f88b02c5cc
--- /dev/null
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -0,0 +1,744 @@
+# Copyright 2024 Genmo and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import Mochi1LoraLoaderMixin
+from ...models import AutoencoderKLMochi, MochiTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import MochiPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import MochiPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
+ >>> pipe.enable_model_cpu_offload()
+ >>> pipe.enable_vae_tiling()
+ >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
+ >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
+ >>> export_to_video(frames, "mochi.mp4")
+ ```
+"""
+
+
+# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
+def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
+ if linear_steps is None:
+ linear_steps = num_steps // 2
+ linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
+ threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
+ quadratic_steps = num_steps - linear_steps
+ quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
+ linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
+ const = quadratic_coef * (linear_steps**2)
+ quadratic_sigma_schedule = [
+ quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
+ ]
+ sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
+ sigma_schedule = [1.0 - x for x in sigma_schedule]
+ return sigma_schedule
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
+ r"""
+ The mochi pipeline for text-to-video generation.
+
+ Reference: https://github.com/genmoai/models
+
+ Args:
+ transformer ([`MochiTransformer3DModel`]):
+ Conditional Transformer architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLMochi`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLMochi,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: MochiTransformer3DModel,
+ force_zeros_for_empty_prompt: bool = False,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ # TODO: determine these scaling factors from model parameters
+ self.vae_spatial_scale_factor = 8
+ self.vae_temporal_scale_factor = 6
+ self.patch_size = 2
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
+ )
+ self.default_height = 480
+ self.default_width = 848
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 256,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
+
+ # The original Mochi implementation zeros out empty negative prompts
+ # but this can lead to overflow when placing the entire pipeline under the autocast context
+ # adding this here so that we can enable zeroing prompts if necessary
+ if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
+ text_input_ids = torch.zeros_like(text_input_ids, device=device)
+ prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 256,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = height // self.vae_spatial_scale_factor
+ width = width // self.vae_spatial_scale_factor
+ num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
+
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
+ latents = latents.to(dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: int = 19,
+ num_inference_steps: int = 64,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.5,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, *optional*, defaults to `self.default_height`):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+ width (`int`, *optional*, defaults to `self.default_width`):
+ The width in pixels of the generated image. This is set to 848 by default for the best results.
+ num_frames (`int`, defaults to `19`):
+ The number of video frames to generate
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, defaults to `4.5`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `256`):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
+ is returned where the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.default_height
+ width = width or self.default_width
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # 3. Prepare text embeddings
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 5. Prepare timestep
+ # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
+ threshold_noise = 0.025
+ sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
+ sigmas = np.array(sigmas)
+
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
+ # to make sure we're using the correct non-reversed timestep values.
+ self._current_timestep = 1000 - t
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ # Mochi CFG + Sampling runs in FP32
+ noise_pred = noise_pred.to(torch.float32)
+
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
+ latents = latents.to(latents_dtype)
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ video = latents
+ else:
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return MochiPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/mochi/pipeline_output.py b/src/diffusers/pipelines/mochi/pipeline_output.py
new file mode 100644
index 000000000000..d15827bc0084
--- /dev/null
+++ b/src/diffusers/pipelines/mochi/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class MochiPipelineOutput(BaseOutput):
+ r"""
+ Output class for Mochi pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
index 728635da6d4d..73837af7d429 100644
--- a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
+++ b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
@@ -42,8 +42,20 @@
if is_librosa_available():
import librosa
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -111,7 +123,7 @@ def __init__(
scheduler=scheduler,
vocoder=vocoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def _encode_prompt(
self,
@@ -603,6 +615,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
self.maybe_free_model_hooks()
# 8. Post-processing
diff --git a/src/diffusers/pipelines/omnigen/__init__.py b/src/diffusers/pipelines/omnigen/__init__.py
new file mode 100644
index 000000000000..557e7c08dc22
--- /dev/null
+++ b/src/diffusers/pipelines/omnigen/__init__.py
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_omnigen"] = ["OmniGenPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_omnigen import OmniGenPipeline
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
new file mode 100644
index 000000000000..5fe5be3b26d2
--- /dev/null
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -0,0 +1,512 @@
+# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import LlamaTokenizer
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import OmniGenTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from .processor_omnigen import OmniGenMultiModalProcessor
+
+
+if is_torch_xla_available():
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import OmniGenPipeline
+
+ >>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
+ >>> image.save("output.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class OmniGenPipeline(
+ DiffusionPipeline,
+):
+ r"""
+ The OmniGen pipeline for multimodal-to-image generation.
+
+ Reference: https://arxiv.org/pdf/2409.11340
+
+ Args:
+ transformer ([`OmniGenTransformer2DModel`]):
+ Autoregressive Transformer architecture for OmniGen.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ tokenizer (`LlamaTokenizer`):
+ Text tokenizer of class.
+ [LlamaTokenizer](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaTokenizer).
+ """
+
+ model_cpu_offload_seq = "transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents"]
+
+ def __init__(
+ self,
+ transformer: OmniGenTransformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ tokenizer: LlamaTokenizer,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) is not None else 8
+ )
+ # OmniGen latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ self.multimodal_processor = OmniGenMultiModalProcessor(tokenizer, max_image_size=1024)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 120000
+ )
+ self.default_sample_size = 128
+
+ def encode_input_images(
+ self,
+ input_pixel_values: List[torch.Tensor],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ get the continue embedding of input images by VAE
+
+ Args:
+ input_pixel_values: normlized pixel of input images
+ device:
+ Returns: torch.Tensor
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.vae.dtype
+
+ input_img_latents = []
+ for img in input_pixel_values:
+ img = self.vae.encode(img.to(device, dtype)).latent_dist.sample().mul_(self.vae.config.scaling_factor)
+ input_img_latents.append(img)
+ return input_img_latents
+
+ def check_inputs(
+ self,
+ prompt,
+ input_images,
+ height,
+ width,
+ use_input_image_size_as_output,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if input_images is not None:
+ if len(input_images) != len(prompt):
+ raise ValueError(
+ f"The number of prompts: {len(prompt)} does not match the number of input images: {len(input_images)}."
+ )
+ for i in range(len(input_images)):
+ if input_images[i] is not None:
+ if not all(f" <|image_{k + 1}|>" in prompt[i] for k in range(len(input_images[i]))):
+ raise ValueError(
+ f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`"
+ )
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if use_input_image_size_as_output:
+ if input_images is None or input_images[0] is None:
+ raise ValueError(
+ "`use_input_image_size_as_output` is set to True, but no input image was found. If you are performing a text-to-image task, please set it to False."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ input_images: Union[PipelineImageInput, List[PipelineImageInput]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ max_input_image_size: int = 1024,
+ timesteps: List[int] = None,
+ guidance_scale: float = 2.5,
+ img_guidance_scale: float = 1.6,
+ use_input_image_size_as_output: bool = False,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If the input includes images, need to add
+ placeholders ` <|image_i|>` in the prompt to indicate the position of the i-th images.
+ input_images (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+ The list of input images. We will replace the "<|image_i|>" in prompt with the i-th image in list.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ max_input_image_size (`int`, *optional*, defaults to 1024):
+ the maximum size of input image, which will be used to crop the input image to the maximum size
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 2.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ img_guidance_scale (`float`, *optional*, defaults to 1.6):
+ Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
+ use_input_image_size_as_output (bool, defaults to False):
+ whether to use the input image size as the output image size, which can be used for single-image input,
+ e.g., image editing task
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ num_cfg = 2 if input_images is not None else 1
+ use_img_cfg = True if input_images is not None else False
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ input_images = [input_images]
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ input_images,
+ height,
+ width,
+ use_input_image_size_as_output,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ # 2. Define call parameters
+ batch_size = len(prompt)
+ device = self._execution_device
+
+ # 3. process multi-modal instructions
+ if max_input_image_size != self.multimodal_processor.max_image_size:
+ self.multimodal_processor.reset_max_image_size(max_image_size=max_input_image_size)
+ processed_data = self.multimodal_processor(
+ prompt,
+ input_images,
+ height=height,
+ width=width,
+ use_img_cfg=use_img_cfg,
+ use_input_image_size_as_output=use_input_image_size_as_output,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ processed_data["input_ids"] = processed_data["input_ids"].to(device)
+ processed_data["attention_mask"] = processed_data["attention_mask"].to(device)
+ processed_data["position_ids"] = processed_data["position_ids"].to(device)
+
+ # 4. Encode input images
+ input_img_latents = self.encode_input_images(processed_data["input_pixel_values"], device=device)
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps]
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
+ )
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare latents
+ transformer_dtype = self.transformer.dtype
+ if use_input_image_size_as_output:
+ height, width = processed_data["input_pixel_values"][0].shape[-2:]
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * (num_cfg + 1))
+ latent_model_input = latent_model_input.to(transformer_dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ input_ids=processed_data["input_ids"],
+ input_img_latents=input_img_latents,
+ input_image_sizes=processed_data["input_image_sizes"],
+ attention_mask=processed_data["attention_mask"],
+ position_ids=processed_data["position_ids"],
+ return_dict=False,
+ )[0]
+
+ if num_cfg == 2:
+ cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0)
+ noise_pred = uncond + img_guidance_scale * (img_cond - uncond) + guidance_scale * (cond - img_cond)
+ else:
+ cond, uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0)
+ noise_pred = uncond + guidance_scale * (cond - uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+
+ progress_bar.update()
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents = latents / self.vae.config.scaling_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ else:
+ image = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py
new file mode 100644
index 000000000000..75d272ac5140
--- /dev/null
+++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py
@@ -0,0 +1,327 @@
+# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Dict, List
+
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+
+
+def crop_image(pil_image, max_image_size):
+ """
+ Crop the image so that its height and width does not exceed `max_image_size`, while ensuring both the height and
+ width are multiples of 16.
+ """
+ while min(*pil_image.size) >= 2 * max_image_size:
+ pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+ if max(*pil_image.size) > max_image_size:
+ scale = max_image_size / max(*pil_image.size)
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+ if min(*pil_image.size) < 16:
+ scale = 16 / min(*pil_image.size)
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+ arr = np.array(pil_image)
+ crop_y1 = (arr.shape[0] % 16) // 2
+ crop_y2 = arr.shape[0] % 16 - crop_y1
+
+ crop_x1 = (arr.shape[1] % 16) // 2
+ crop_x2 = arr.shape[1] % 16 - crop_x1
+
+ arr = arr[crop_y1 : arr.shape[0] - crop_y2, crop_x1 : arr.shape[1] - crop_x2]
+ return Image.fromarray(arr)
+
+
+class OmniGenMultiModalProcessor:
+ def __init__(self, text_tokenizer, max_image_size: int = 1024):
+ self.text_tokenizer = text_tokenizer
+ self.max_image_size = max_image_size
+
+ self.image_transform = transforms.Compose(
+ [
+ transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+
+ self.collator = OmniGenCollator()
+
+ def reset_max_image_size(self, max_image_size):
+ self.max_image_size = max_image_size
+ self.image_transform = transforms.Compose(
+ [
+ transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+
+ def process_image(self, image):
+ if isinstance(image, str):
+ image = Image.open(image).convert("RGB")
+ return self.image_transform(image)
+
+ def process_multi_modal_prompt(self, text, input_images):
+ text = self.add_prefix_instruction(text)
+ if input_images is None or len(input_images) == 0:
+ model_inputs = self.text_tokenizer(text)
+ return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
+
+ pattern = r"<\|image_\d+\|>"
+ prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
+
+ for i in range(1, len(prompt_chunks)):
+ if prompt_chunks[i][0] == 1:
+ prompt_chunks[i] = prompt_chunks[i][1:]
+
+ image_tags = re.findall(pattern, text)
+ image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
+
+ unique_image_ids = sorted(set(image_ids))
+ assert unique_image_ids == list(
+ range(1, len(unique_image_ids) + 1)
+ ), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+ # total images must be the same as the number of image tags
+ assert (
+ len(unique_image_ids) == len(input_images)
+ ), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+
+ input_images = [input_images[x - 1] for x in image_ids]
+
+ all_input_ids = []
+ img_inx = []
+ for i in range(len(prompt_chunks)):
+ all_input_ids.extend(prompt_chunks[i])
+ if i != len(prompt_chunks) - 1:
+ start_inx = len(all_input_ids)
+ size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
+ img_inx.append([start_inx, start_inx + size])
+ all_input_ids.extend([0] * size)
+
+ return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
+
+ def add_prefix_instruction(self, prompt):
+ user_prompt = "<|user|>\n"
+ generation_prompt = "Generate an image according to the following instructions\n"
+ assistant_prompt = "<|assistant|>\n<|diffusion|>"
+ prompt_suffix = "<|end|>\n"
+ prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
+ return prompt
+
+ def __call__(
+ self,
+ instructions: List[str],
+ input_images: List[List[str]] = None,
+ height: int = 1024,
+ width: int = 1024,
+ negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
+ use_img_cfg: bool = True,
+ separate_cfg_input: bool = False,
+ use_input_image_size_as_output: bool = False,
+ num_images_per_prompt: int = 1,
+ ) -> Dict:
+ if isinstance(instructions, str):
+ instructions = [instructions]
+ input_images = [input_images]
+
+ input_data = []
+ for i in range(len(instructions)):
+ cur_instruction = instructions[i]
+ cur_input_images = None if input_images is None else input_images[i]
+ if cur_input_images is not None and len(cur_input_images) > 0:
+ cur_input_images = [self.process_image(x) for x in cur_input_images]
+ else:
+ cur_input_images = None
+ assert " <|image_1|>" not in cur_instruction
+
+ mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
+
+ neg_mllm_input, img_cfg_mllm_input = None, None
+ neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
+ if use_img_cfg:
+ if cur_input_images is not None and len(cur_input_images) >= 1:
+ img_cfg_prompt = [f" <|image_{i + 1}|>" for i in range(len(cur_input_images))]
+ img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
+ else:
+ img_cfg_mllm_input = neg_mllm_input
+
+ for _ in range(num_images_per_prompt):
+ if use_input_image_size_as_output:
+ input_data.append(
+ (
+ mllm_input,
+ neg_mllm_input,
+ img_cfg_mllm_input,
+ [mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)],
+ )
+ )
+ else:
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
+
+ return self.collator(input_data)
+
+
+class OmniGenCollator:
+ def __init__(self, pad_token_id=2, hidden_size=3072):
+ self.pad_token_id = pad_token_id
+ self.hidden_size = hidden_size
+
+ def create_position(self, attention_mask, num_tokens_for_output_images):
+ position_ids = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ temp_position = [0] * (text_length - temp_l) + list(
+ range(temp_l + img_length + 1)
+ ) # we add a time embedding into the sequence, so add one more token
+ position_ids.append(temp_position)
+ return torch.LongTensor(position_ids)
+
+ def create_mask(self, attention_mask, num_tokens_for_output_images):
+ """
+ OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within
+ each image sequence References: [OmniGen](https://arxiv.org/pdf/2409.11340)
+ """
+ extended_mask = []
+ padding_images = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
+ inx = 0
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ pad_l = text_length - temp_l
+
+ temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1)))
+
+ image_mask = torch.zeros(size=(temp_l + 1, img_length))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
+
+ image_mask = torch.ones(size=(img_length, temp_l + img_length + 1))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=0)
+
+ if pad_l > 0:
+ pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
+
+ pad_mask = torch.ones(size=(pad_l, seq_len))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
+
+ true_img_length = num_tokens_for_output_images[inx]
+ pad_img_length = img_length - true_img_length
+ if pad_img_length > 0:
+ temp_mask[:, -pad_img_length:] = 0
+ temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
+ else:
+ temp_padding_imgs = None
+
+ extended_mask.append(temp_mask.unsqueeze(0))
+ padding_images.append(temp_padding_imgs)
+ inx += 1
+ return torch.cat(extended_mask, dim=0), padding_images
+
+ def adjust_attention_for_input_images(self, attention_mask, image_sizes):
+ for b_inx in image_sizes.keys():
+ for start_inx, end_inx in image_sizes[b_inx]:
+ attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
+
+ return attention_mask
+
+ def pad_input_ids(self, input_ids, image_sizes):
+ max_l = max([len(x) for x in input_ids])
+ padded_ids = []
+ attention_mask = []
+
+ for i in range(len(input_ids)):
+ temp_ids = input_ids[i]
+ temp_l = len(temp_ids)
+ pad_l = max_l - temp_l
+ if pad_l == 0:
+ attention_mask.append([1] * max_l)
+ padded_ids.append(temp_ids)
+ else:
+ attention_mask.append([0] * pad_l + [1] * temp_l)
+ padded_ids.append([self.pad_token_id] * pad_l + temp_ids)
+
+ if i in image_sizes:
+ new_inx = []
+ for old_inx in image_sizes[i]:
+ new_inx.append([x + pad_l for x in old_inx])
+ image_sizes[i] = new_inx
+
+ return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
+
+ def process_mllm_input(self, mllm_inputs, target_img_size):
+ num_tokens_for_output_images = []
+ for img_size in target_img_size:
+ num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16)
+
+ pixel_values, image_sizes = [], {}
+ b_inx = 0
+ for x in mllm_inputs:
+ if x["pixel_values"] is not None:
+ pixel_values.extend(x["pixel_values"])
+ for size in x["image_sizes"]:
+ if b_inx not in image_sizes:
+ image_sizes[b_inx] = [size]
+ else:
+ image_sizes[b_inx].append(size)
+ b_inx += 1
+ pixel_values = [x.unsqueeze(0) for x in pixel_values]
+
+ input_ids = [x["input_ids"] for x in mllm_inputs]
+ padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
+ position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
+ attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
+ attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
+
+ return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
+
+ def __call__(self, features):
+ mllm_inputs = [f[0] for f in features]
+ cfg_mllm_inputs = [f[1] for f in features]
+ img_cfg_mllm_input = [f[2] for f in features]
+ target_img_size = [f[3] for f in features]
+
+ if img_cfg_mllm_input[0] is not None:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
+ target_img_size = target_img_size + target_img_size + target_img_size
+ else:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs
+ target_img_size = target_img_size + target_img_size
+
+ (
+ all_padded_input_ids,
+ all_position_ids,
+ all_attention_mask,
+ all_padding_images,
+ all_pixel_values,
+ all_image_sizes,
+ ) = self.process_mllm_input(mllm_inputs, target_img_size)
+
+ data = {
+ "input_ids": all_padded_input_ids,
+ "attention_mask": all_attention_mask,
+ "position_ids": all_position_ids,
+ "input_pixel_values": all_pixel_values,
+ "input_image_sizes": all_image_sizes,
+ }
+ return data
diff --git a/src/diffusers/pipelines/onnx_utils.py b/src/diffusers/pipelines/onnx_utils.py
index 11f2241c64c8..0e12340f6895 100644
--- a/src/diffusers/pipelines/onnx_utils.py
+++ b/src/diffusers/pipelines/onnx_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -61,7 +61,7 @@ def __call__(self, **kwargs):
return self.model.run(None, inputs)
@staticmethod
- def load_model(path: Union[str, Path], provider=None, sess_options=None):
+ def load_model(path: Union[str, Path], provider=None, sess_options=None, provider_options=None):
"""
Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
@@ -75,7 +75,9 @@ def load_model(path: Union[str, Path], provider=None, sess_options=None):
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider"
- return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
+ return ort.InferenceSession(
+ path, providers=[provider], sess_options=sess_options, provider_options=provider_options
+ )
def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
"""
diff --git a/src/diffusers/pipelines/pag/__init__.py b/src/diffusers/pipelines/pag/__init__.py
index 6a6723b58ca9..176efe3adef6 100644
--- a/src/diffusers/pipelines/pag/__init__.py
+++ b/src/diffusers/pipelines/pag/__init__.py
@@ -29,10 +29,14 @@
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
_import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"]
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
+ _import_structure["pipeline_pag_sana"] = ["SanaPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
+ _import_structure["pipeline_pag_sd_3_img2img"] = ["StableDiffusion3PAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_img2img"] = ["StableDiffusionPAGImg2ImgPipeline"]
+ _import_structure["pipeline_pag_sd_inpaint"] = ["StableDiffusionPAGInpaintPipeline"]
+
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
@@ -52,10 +56,13 @@
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
from .pipeline_pag_kolors import KolorsPAGPipeline
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
+ from .pipeline_pag_sana import SanaPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
+ from .pipeline_pag_sd_3_img2img import StableDiffusion3PAGImg2ImgPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_img2img import StableDiffusionPAGImg2ImgPipeline
+ from .pipeline_pag_sd_inpaint import StableDiffusionPAGInpaintPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py
index 7a6e30a3c6be..4cd2fe4cb79f 100644
--- a/src/diffusers/pipelines/pag/pag_utils.py
+++ b/src/diffusers/pipelines/pag/pag_utils.py
@@ -158,7 +158,7 @@ def set_pag_applied_layers(
),
):
r"""
- Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid.
+ Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid.
Args:
pag_applied_layers (`str` or `List[str]`):
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
index 00f960797d0e..bc90073cba77 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
@@ -25,24 +25,31 @@
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
-from ..controlnet.multicontrolnet import MultiControlNetModel
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -252,7 +259,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -1294,6 +1301,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
index f5f117ab7625..bc7a4b57affd 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
@@ -26,24 +26,31 @@
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
-from ..controlnet.multicontrolnet import MultiControlNetModel
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -229,7 +236,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -1506,6 +1513,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
index 4cfb32d1de97..83540885bfb2 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
@@ -38,7 +38,7 @@
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -61,8 +61,16 @@
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
-from ..controlnet.multicontrolnet import MultiControlNetModel
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -282,7 +290,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -423,7 +431,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -482,8 +492,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1562,6 +1574,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
index 66398483e046..b84f5d555914 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
@@ -38,7 +38,7 @@
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -61,8 +61,16 @@
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
-from ..controlnet.multicontrolnet import MultiControlNetModel
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -272,7 +280,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -415,7 +423,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -474,8 +484,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1628,6 +1640,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
index 408992378538..a6a8deb5883c 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
@@ -245,9 +245,7 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = (
@@ -818,7 +816,11 @@ def __call__(
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
- self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
+ self.transformer.inner_dim // self.transformer.num_heads,
+ grid_crops_coords,
+ (grid_height, grid_width),
+ device=device,
+ output_type="pt",
)
style = torch.tensor([0], device=device)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
index 3e84f44adcf7..62f634312ada 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
@@ -202,12 +202,14 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
self.set_pag_applied_layers(pag_applied_layers)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
index 59d6a9001e1f..affda7e18add 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
@@ -29,6 +29,7 @@
deprecate,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -43,8 +44,16 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -172,7 +181,7 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.set_pag_applied_layers(pag_applied_layers)
@@ -227,13 +236,6 @@ def encode_prompt(
if device is None:
device = self._execution_device
- if prompt is not None and isinstance(prompt, str):
- batch_size = 1
- elif prompt is not None and isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- batch_size = prompt_embeds.shape[0]
-
# See Section 3.1. of the paper.
max_length = max_sequence_length
@@ -278,12 +280,12 @@ def encode_prompt(
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
- prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
- uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
@@ -310,10 +312,10 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
@@ -805,10 +807,11 @@ def __call__(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
+ is_npu = latent_model_input.device.type == "npu"
if isinstance(current_timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
@@ -850,6 +853,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
new file mode 100644
index 000000000000..030ab6db7391
--- /dev/null
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
@@ -0,0 +1,943 @@
+# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+import warnings
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PixArtImageProcessor
+from ...models import AutoencoderDC, SanaTransformer2DModel
+from ...models.attention_processor import PAGCFGSanaLinearAttnProcessor2_0, PAGIdentitySanaLinearAttnProcessor2_0
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pixart_alpha.pipeline_pixart_alpha import (
+ ASPECT_RATIO_512_BIN,
+ ASPECT_RATIO_1024_BIN,
+)
+from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
+from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN
+from .pag_utils import PAGMixin
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import SanaPAGPipeline
+
+ >>> pipe = SanaPAGPipeline.from_pretrained(
+ ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
+ ... pag_applied_layers=["transformer_blocks.8"],
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe.to("cuda")
+ >>> pipe.text_encoder.to(torch.bfloat16)
+ >>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
+
+ >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
+ >>> image[0].save("output.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
+ r"""
+ Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629). This pipeline
+ supports the use of [Perturbed Attention Guidance
+ (PAG)](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag).
+ """
+
+ # fmt: off
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
+ # fmt: on
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ text_encoder: Gemma2PreTrainedModel,
+ vae: AutoencoderDC,
+ transformer: SanaTransformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ pag_applied_layers: Union[str, List[str]] = "transformer_blocks.0",
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
+ if hasattr(self, "vae") and self.vae is not None
+ else 8
+ )
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ self.set_pag_applied_layers(
+ pag_applied_layers,
+ pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()),
+ )
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+ select_index = [0] + list(range(-max_length + 1, 0))
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+ # prepare complex human instruction
+ if not complex_human_instruction:
+ max_length_all = max_length
+ else:
+ chi_prompt = "\n".join(complex_human_instruction)
+ prompt = [chi_prompt + p for p in prompt]
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+ max_length_all = num_chi_prompt_tokens + max_length - 2
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length_all,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0][:, select_index]
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
+
+ if self.transformer is not None:
+ dtype = self.transformer.dtype
+ elif self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ negative_prompt_attention_mask = uncond_input.attention_mask
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 20,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ height: int = 1024,
+ width: int = 1024,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clean_caption: bool = False,
+ use_resolution_binning: bool = True,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 300,
+ complex_human_instruction: List[str] = [
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+ "Here are examples of how to transform or refine prompts:",
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+ "User Prompt: ",
+ ],
+ pag_scale: float = 3.0,
+ pag_adaptive_scale: float = 0.0,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool` defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+ the requested resolution. Useful for generating non-square images.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`.
+ complex_human_instruction (`List[str]`, *optional*):
+ Instructions for complex human attention:
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
+ pag_scale (`float`, *optional*, defaults to 3.0):
+ The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
+ guidance will not be used.
+ pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
+ The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
+ used.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 128:
+ aspect_ratio_bin = ASPECT_RATIO_4096_BIN
+ elif self.transformer.config.sample_size == 64:
+ aspect_ratio_bin = ASPECT_RATIO_2048_BIN
+ elif self.transformer.config.sample_size == 32:
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+ elif self.transformer.config.sample_size == 16:
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ self._pag_scale = pag_scale
+ self._pag_adaptive_scale = pag_adaptive_scale
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ )
+
+ if self.do_perturbed_attention_guidance:
+ prompt_embeds = self._prepare_perturbed_attention_guidance(
+ prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
+ )
+ prompt_attention_mask = self._prepare_perturbed_attention_guidance(
+ prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance
+ )
+ elif self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+ if self.do_perturbed_attention_guidance:
+ original_attn_proc = self.transformer.attn_processors
+ self._set_pag_attn_processor(
+ pag_applied_layers=self.pag_applied_layers,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both
+ latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timestep,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if self.do_perturbed_attention_guidance:
+ noise_pred = self._apply_perturbed_attention_guidance(
+ noise_pred, self.do_classifier_free_guidance, guidance_scale, t
+ )
+ elif self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute previous image: x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ try:
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ except torch.cuda.OutOfMemoryError as e:
+ warnings.warn(
+ f"{e}. \n"
+ f"Try to use VAE tiling for large images. For example: \n"
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+ )
+ if use_resolution_binning:
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if self.do_perturbed_attention_guidance:
+ self.transformer.set_attn_processor(original_attn_proc)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
index 6220a00f2c22..fc7dc3a83f27 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
@@ -27,6 +27,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -39,8 +40,16 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -207,7 +216,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -221,7 +230,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -250,10 +259,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -281,7 +294,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1034,6 +1047,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
index c6f9077ad3da..fde3e500a573 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
@@ -200,9 +200,7 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -377,9 +375,9 @@ def encode_prompt(
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
- negative_prompt_2 (`str` or `List[str]`, *optional*):
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
- `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -693,7 +691,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -735,10 +733,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -890,7 +888,7 @@ def __call__(
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
new file mode 100644
index 000000000000..d64582a26f7a
--- /dev/null
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
@@ -0,0 +1,1056 @@
+# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import PIL.Image
+import torch
+from transformers import (
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...models.attention_processor import PAGCFGJointAttnProcessor2_0, PAGJointAttnProcessor2_0
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import SD3Transformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
+from .pag_utils import PAGMixin
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusion3PAGImg2ImgPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = StableDiffusion3PAGImg2ImgPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-3-medium-diffusers",
+ ... torch_dtype=torch.float16,
+ ... pag_applied_layers=["blocks.13"],
+ ... )
+ >>> pipe.to("cuda")
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
+ >>> init_image = load_image(url).convert("RGB")
+ >>> image = pipe(prompt, image=init_image, guidance_scale=5.0, pag_scale=0.7).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, PAGMixin):
+ r"""
+ [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for image-to-image generation
+ using Stable Diffusion 3.
+
+ Args:
+ transformer ([`SD3Transformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
+ as its dimension.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ text_encoder_3 ([`T5EncoderModel`]):
+ Frozen text-encoder. Stable Diffusion 3 uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_3 (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
+
+ def __init__(
+ self,
+ transformer: SD3Transformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer_2: CLIPTokenizer,
+ text_encoder_3: T5EncoderModel,
+ tokenizer_3: T5TokenizerFast,
+ pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ text_encoder_3=text_encoder_3,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ tokenizer_3=tokenizer_3,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = (
+ self.transformer.config.sample_size
+ if hasattr(self, "transformer") and self.transformer is not None
+ else 128
+ )
+ self.patch_size = (
+ self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
+ )
+
+ self.set_pag_applied_layers(
+ pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0())
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 256,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if self.text_encoder_3 is None:
+ return torch.zeros(
+ (
+ batch_size * num_images_per_prompt,
+ self.tokenizer_max_length,
+ self.transformer.config.joint_attention_dim,
+ ),
+ device=device,
+ dtype=dtype,
+ )
+
+ text_inputs = self.tokenizer_3(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
+
+ dtype = self.text_encoder_3.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ clip_skip: Optional[int] = None,
+ clip_model_index: int = 0,
+ ):
+ device = device or self._execution_device
+
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
+
+ tokenizer = clip_tokenizers[clip_model_index]
+ text_encoder = clip_text_encoders[clip_model_index]
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+ pooled_prompt_embeds = prompt_embeds[0]
+
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ prompt_3: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ clip_skip: Optional[int] = None,
+ max_sequence_length: int = 256,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ prompt_3 = prompt_3 or prompt
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=clip_skip,
+ clip_model_index=0,
+ )
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+ prompt=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=clip_skip,
+ clip_model_index=1,
+ )
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
+
+ t5_prompt_embed = self._get_t5_prompt_embeds(
+ prompt=prompt_3,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ clip_prompt_embeds = torch.nn.functional.pad(
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+ )
+
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+ negative_prompt_3 = (
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
+ )
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
+ negative_prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=0,
+ )
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+ negative_prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=1,
+ )
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
+
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
+ prompt=negative_prompt_3,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
+ negative_clip_prompt_embeds,
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
+ )
+
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
+ negative_pooled_prompt_embeds = torch.cat(
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ strength,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ negative_prompt_3=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if (
+ height % (self.vae_scale_factor * self.patch_size) != 0
+ or width % (self.vae_scale_factor * self.patch_size) != 0
+ ):
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
+ )
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_3 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.prepare_latents
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+ if image.shape[1] == self.vae.config.latent_channels:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
+ latents = init_latents.to(device=device, dtype=dtype)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ image: PipelineImageInput = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ pag_scale: float = 3.0,
+ pag_adaptive_scale: float = 0.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead
+ prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
+ will be used instead
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ strength (`float`, *optional*, defaults to 0.8):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+ pag_scale (`float`, *optional*, defaults to 3.0):
+ The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
+ guidance will not be used.
+ pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
+ The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
+ used.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ strength,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+ self._pag_scale = pag_scale
+ self._pag_adaptive_scale = pag_adaptive_scale
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ if self.do_perturbed_attention_guidance:
+ prompt_embeds = self._prepare_perturbed_attention_guidance(
+ prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
+ )
+ pooled_prompt_embeds = self._prepare_perturbed_attention_guidance(
+ pooled_prompt_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance
+ )
+ elif self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+ # 3. Preprocess image
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ if latents is None:
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ if self.do_perturbed_attention_guidance:
+ original_attn_proc = self.transformer.attn_processors
+ self._set_pag_attn_processor(
+ pag_applied_layers=self.pag_applied_layers,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both
+ latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_perturbed_attention_guidance:
+ noise_pred = self._apply_perturbed_attention_guidance(
+ noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
+ )
+
+ elif self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if self.do_perturbed_attention_guidance:
+ self.transformer.set_attn_processor(original_attn_proc)
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusion3PipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
index 1e81fa3a158c..d3a015e569c1 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
@@ -26,6 +26,7 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -40,8 +41,16 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -147,7 +156,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
self.set_pag_applied_layers(pag_applied_layers)
@@ -847,6 +856,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
index 49dc4948cb40..d91c02b607a3 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
@@ -30,6 +30,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -42,8 +43,16 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -202,7 +211,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -216,7 +225,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -245,10 +254,14 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -276,7 +289,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1063,6 +1076,12 @@ def __call__(
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
new file mode 100644
index 000000000000..33abfb0be89f
--- /dev/null
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
@@ -0,0 +1,1372 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from .pag_utils import PAGMixin
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import AutoPipelineForInpainting
+
+ >>> pipe = AutoPipelineForInpainting.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True
+ ... )
+ >>> pipe = pipe.to("cuda")
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ >>> init_image = load_image(img_url).convert("RGB")
+ >>> mask_image = load_image(mask_url).convert("RGB")
+ >>> prompt = "A majestic tiger sitting on a bench"
+ >>> image = pipe(
+ ... prompt=prompt,
+ ... image=init_image,
+ ... mask_image=mask_image,
+ ... strength=0.8,
+ ... num_inference_steps=50,
+ ... guidance_scale=guidance_scale,
+ ... generator=generator,
+ ... pag_scale=pag_scale,
+ ... ).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionPAGInpaintPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ IPAdapterMixin,
+ FromSingleFileMixin,
+ PAGMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ pag_applied_layers: Union[str, List[str]] = "mid",
+ ):
+ super().__init__()
+
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ self.set_pag_applied_layers(pag_applied_layers)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ image_embeds = []
+ if do_classifier_free_guidance:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if do_classifier_free_guidance:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ mask_image,
+ height,
+ width,
+ strength,
+ callback_steps,
+ output_type,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+
+ if image.shape[1] == 4:
+ image_latents = image
+ else:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ else:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: torch.Tensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.9999,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ pag_scale: float = 3.0,
+ pag_adaptive_scale: float = 0.0,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ pag_scale (`float`, *optional*, defaults to 3.0):
+ The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
+ guidance will not be used.
+ pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
+ The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
+ used.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+ # to deal with lora scaling and other possible forward hooks
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ height,
+ width,
+ strength,
+ None,
+ None,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ padding_mask_crop,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+ self._pag_scale = pag_scale
+ self._pag_adaptive_scale = pag_adaptive_scale
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. set timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps=num_inference_steps, strength=strength, device=device
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is None:
+ masked_image = init_image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ )
+ if self.do_perturbed_attention_guidance:
+ if self.do_classifier_free_guidance:
+ mask, _ = mask.chunk(2)
+ masked_image_latents, _ = masked_image_latents.chunk(2)
+ mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance)
+ masked_image_latents = self._prepare_perturbed_attention_guidance(
+ masked_image_latents, masked_image_latents, self.do_classifier_free_guidance
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+ # 9 Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+
+ if self.do_perturbed_attention_guidance:
+ prompt_embeds = self._prepare_perturbed_attention_guidance(
+ prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
+ )
+ elif self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ for i, image_embeds in enumerate(ip_adapter_image_embeds):
+ negative_image_embeds = None
+ if self.do_classifier_free_guidance:
+ negative_image_embeds, image_embeds = image_embeds.chunk(2)
+ if self.do_perturbed_attention_guidance:
+ image_embeds = self._prepare_perturbed_attention_guidance(
+ image_embeds, negative_image_embeds, self.do_classifier_free_guidance
+ )
+
+ elif self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+ image_embeds = image_embeds.to(device)
+ ip_adapter_image_embeds[i] = image_embeds
+
+ # 9.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": ip_adapter_image_embeds}
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
+ else None
+ )
+
+ # 9.2 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ if self.do_perturbed_attention_guidance:
+ original_attn_proc = self.unet.attn_processors
+ self._set_pag_attn_processor(
+ pag_applied_layers=self.pag_applied_layers,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ )
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_perturbed_attention_guidance:
+ noise_pred = self._apply_perturbed_attention_guidance(
+ noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
+ )
+
+ elif self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if self.do_perturbed_attention_guidance:
+ init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2)
+ else:
+ init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ mask = callback_outputs.pop("mask", mask)
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ condition_kwargs = {}
+ if isinstance(self.vae, AsymmetricAutoencoderKL):
+ init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
+ init_image_condition = init_image.clone()
+ init_image = self._encode_vae_image(init_image, generator=generator)
+ mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
+ condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
+ image = self.vae.decode(
+ latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
+ )[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if padding_mask_crop is not None:
+ image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if self.do_perturbed_attention_guidance:
+ self.unet.set_attn_processor(original_attn_proc)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
index c2611164a049..856f6a3e789e 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
@@ -275,10 +275,14 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -415,7 +419,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -474,8 +480,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
index 6d634d524848..93dcca0ea9d6 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
@@ -298,7 +298,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -436,7 +436,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -495,8 +497,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
index 7f85c13ac561..fdf3df2f4d6a 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
@@ -314,7 +314,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -526,7 +526,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -585,8 +587,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index b225fd71edf8..55a9f47145a2 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -23,7 +23,7 @@
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -31,6 +31,13 @@
from .image_encoder import PaintByExampleImageEncoder
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -209,7 +216,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -604,6 +611,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
self.maybe_free_model_hooks()
if not output_type == "latent":
diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py
index b7dfcd39edce..df8499ab900a 100644
--- a/src/diffusers/pipelines/pia/pipeline_pia.py
+++ b/src/diffusers/pipelines/pia/pipeline_pia.py
@@ -37,6 +37,7 @@
from ...utils import (
USE_PEFT_BACKEND,
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -48,8 +49,16 @@
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -195,7 +204,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
@@ -928,6 +937,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py
index c4c212873a88..ec2f82bcf742 100644
--- a/src/diffusers/pipelines/pipeline_flax_utils.py
+++ b/src/diffusers/pipelines/pipeline_flax_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -180,7 +180,7 @@ class implements both a save and loading method. The pipeline is easily reloaded
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
- private = kwargs.pop("private", False)
+ private = kwargs.pop("private", None)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -237,15 +237,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If you get the error message below, you need to finetune the weights for your downstream task:
```
- Some weights of FlaxUNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of FlaxUNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
```
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- - A string, the *repo id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained pipeline
- hosted on the Hub.
+ - A string, the *repo id* (for example `stable-diffusion-v1-5/stable-diffusion-v1-5`) of a
+ pretrained pipeline hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
using [`~FlaxDiffusionPipeline.save_pretrained`].
dtype (`str` or `jnp.dtype`, *optional*):
@@ -293,7 +293,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
>>> # Requires to be logged in to Hugging Face hub,
>>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5",
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
... variant="bf16",
... dtype=jnp.bfloat16,
... )
@@ -301,7 +301,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
>>> # Download pipeline, but use a different scheduler
>>> from diffusers import FlaxDPMSolverMultistepScheduler
- >>> model_id = "runwayml/stable-diffusion-v1-5"
+ >>> model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
>>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
... model_id,
... subfolder="scheduler",
@@ -559,7 +559,7 @@ def components(self) -> Dict[str, Any]:
... )
>>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16
... )
>>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
```
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index c16bd8ac2069..f5b430564ca1 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,19 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
import importlib
import os
import re
import warnings
from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
+import requests
import torch
-from huggingface_hub import ModelCard, model_info
-from huggingface_hub.utils import validate_hf_hub_args
+from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download
+from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
from packaging import version
+from requests.exceptions import HTTPError
from .. import __version__
from ..utils import (
@@ -38,14 +38,16 @@
is_accelerate_available,
is_peft_available,
is_transformers_available,
+ is_transformers_version,
logging,
)
from ..utils.torch_utils import is_compiled_module
+from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transformers_model_from_dduf
if is_transformers_available():
import transformers
- from transformers import PreTrainedModel
+ from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
@@ -102,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
extension is replaced with ".safetensors"
"""
passed_components = passed_components or []
- if folder_names is not None:
+ if folder_names:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
# extract all components of the pipeline and their associated files
@@ -118,6 +120,10 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
components.setdefault(component, [])
components[component].append(component_filename)
+ # If there are no component folders check the main directory for safetensors files
+ if not components:
+ return any(".safetensors" in filename for filename in filenames)
+
# iterate over all files of a component
# check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
@@ -135,7 +141,25 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
return True
-def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
+def filter_model_files(filenames):
+ """Filter model repo files for just files/folders that contain model weights"""
+ weight_names = [
+ WEIGHTS_NAME,
+ SAFETENSORS_WEIGHTS_NAME,
+ FLAX_WEIGHTS_NAME,
+ ONNX_WEIGHTS_NAME,
+ ONNX_EXTERNAL_WEIGHTS_NAME,
+ ]
+
+ if is_transformers_available():
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+
+ allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
+
+ return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
+
+
+def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
@@ -163,6 +187,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
variant_index_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)
+ legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
+ legacy_variant_index_re = re.compile(
+ rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
+ )
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
non_variant_file_re = re.compile(
@@ -171,33 +199,68 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
# `text_encoder/pytorch_model.bin.index.json`
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
- if variant is not None:
- variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
- variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
- variant_filenames = variant_weights | variant_indexes
- else:
- variant_filenames = set()
+ def filter_for_compatible_extensions(filenames, ignore_patterns=None):
+ if not ignore_patterns:
+ return filenames
- non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
- non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
- non_variant_filenames = non_variant_weights | non_variant_indexes
+ # ignore patterns uses glob style patterns e.g *.safetensors but we're only
+ # interested in the extension name
+ return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
- # all variant filenames will be used by default
- usable_filenames = set(variant_filenames)
+ def filter_with_regex(filenames, pattern_re):
+ return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
+
+ # Group files by component
+ components = {}
+ for filename in filenames:
+ if not len(filename.split("/")) == 2:
+ components.setdefault("", []).append(filename)
+ continue
+
+ component, _ = filename.split("/")
+ components.setdefault(component, []).append(filename)
+
+ usable_filenames = set()
+ variant_filenames = set()
+ for component, component_filenames in components.items():
+ component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)
+
+ component_variants = set()
+ component_legacy_variants = set()
+ component_non_variants = set()
+ if variant is not None:
+ component_variants = filter_with_regex(component_filenames, variant_file_re)
+ component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
+
+ component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
+ component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)
+
+ if component_variants or component_legacy_variants:
+ variant_filenames.update(
+ component_variants | component_variant_index_files
+ if component_variants
+ else component_legacy_variants | component_legacy_variant_index_files
+ )
- def convert_to_variant(filename):
- if "index" in filename:
- variant_filename = filename.replace("index", f"index.{variant}")
- elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
- variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
else:
- variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
- return variant_filename
+ component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
+ component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)
+
+ usable_filenames.update(component_non_variants | component_variant_index_files)
+
+ usable_filenames.update(variant_filenames)
- for f in non_variant_filenames:
- variant_filename = convert_to_variant(f)
- if variant_filename not in usable_filenames:
- usable_filenames.add(f)
+ if len(variant_filenames) == 0 and variant is not None:
+ error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
+ raise ValueError(error_message)
+
+ if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
+ logger.warning(
+ f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
+ f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
+ f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not "
+ f"expected, please check your folder structure."
+ )
return usable_filenames, variant_filenames
@@ -529,6 +592,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
loaded_sub_model = passed_class_obj[name]
else:
+ sub_model_dtype = (
+ torch_dtype.get(name, torch_dtype.get("default", torch.float32))
+ if isinstance(torch_dtype, dict)
+ else torch_dtype
+ )
loaded_sub_model = _load_empty_model(
library_name=library_name,
class_name=class_name,
@@ -537,7 +605,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
is_pipeline_module=is_pipeline_module,
pipeline_class=pipeline_class,
name=name,
- torch_dtype=torch_dtype,
+ torch_dtype=sub_model_dtype,
cached_folder=kwargs.get("cached_folder", None),
force_download=kwargs.get("force_download", None),
proxies=kwargs.get("proxies", None),
@@ -553,7 +621,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
# Obtain a sorted dictionary for mapping the model-level components
# to their sizes.
module_sizes = {
- module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
+ module_name: compute_module_sizes(
+ module,
+ dtype=torch_dtype.get(module_name, torch_dtype.get("default", torch.float32))
+ if isinstance(torch_dtype, dict)
+ else torch_dtype,
+ )[""]
for module_name, module in init_empty_modules.items()
if isinstance(module, torch.nn.Module)
}
@@ -602,6 +675,8 @@ def load_sub_model(
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
+ dduf_entries: Optional[Dict[str, DDUFEntry]],
+ provider_options: Any,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
@@ -638,7 +713,7 @@ def load_sub_model(
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
)
- load_method = getattr(class_obj, load_method_name)
+ load_method = _get_load_method(class_obj, load_method_name, is_dduf=dduf_entries is not None)
# add kwargs to loading method
diffusers_module = importlib.import_module(__name__.split(".")[0])
@@ -648,6 +723,7 @@ def load_sub_model(
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
+ loading_kwargs["provider_options"] = provider_options
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
@@ -696,7 +772,10 @@ def load_sub_model(
loading_kwargs["low_cpu_mem_usage"] = False
# check if the module is in a subdirectory
- if os.path.isdir(os.path.join(cached_folder, name)):
+ if dduf_entries:
+ loading_kwargs["dduf_entries"] = dduf_entries
+ loaded_sub_model = load_method(name, **loading_kwargs)
+ elif os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
else:
# else load from the root directory
@@ -721,6 +800,22 @@ def load_sub_model(
return loaded_sub_model
+def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) -> Callable:
+ """
+ Return the method to load the sub model.
+
+ In practice, this method will return the `"from_pretrained"` (or `load_method_name`) method of the class object
+ except if loading from a DDUF checkpoint. In that case, transformers models and tokenizers have a specific loading
+ method that we need to use.
+ """
+ if is_dduf:
+ if issubclass(class_obj, PreTrainedTokenizerBase):
+ return lambda *args, **kwargs: _load_tokenizer_from_dduf(class_obj, *args, **kwargs)
+ if issubclass(class_obj, PreTrainedModel):
+ return lambda *args, **kwargs: _load_transformers_model_from_dduf(class_obj, *args, **kwargs)
+ return getattr(class_obj, load_method_name)
+
+
def _fetch_class_library_tuple(module):
# import it here to avoid circular import
diffusers_module = importlib.import_module(__name__.split(".")[0])
@@ -788,9 +883,9 @@ def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
- " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
+ " checkpoint: https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting instead or adapting your"
f" checkpoint {pretrained_model_name_or_path} to the format of"
- " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
+ " https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting. Note that we do not actively maintain"
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
)
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
@@ -873,10 +968,6 @@ def _get_custom_components_and_folders(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)
- if len(variant_filenames) == 0 and variant is not None:
- error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
- raise ValueError(error_message)
-
return custom_components, folder_names
@@ -884,7 +975,6 @@ def _get_ignore_patterns(
passed_components,
model_folder_names: List[str],
model_filenames: List[str],
- variant_filenames: List[str],
use_safetensors: bool,
from_flax: bool,
allow_pickle: bool,
@@ -915,16 +1005,6 @@ def _get_ignore_patterns(
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
- safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
- safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
- if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
- logger.warning(
- f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
- f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
- f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
- f"expected, please check your folder structure."
- )
-
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]
@@ -932,14 +1012,71 @@ def _get_ignore_patterns(
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
- bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
- bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
- if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
- logger.warning(
- f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
- f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
- f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
- f"your folder structure."
- )
-
return ignore_patterns
+
+
+def _download_dduf_file(
+ pretrained_model_name: str,
+ dduf_file: str,
+ pipeline_class_name: str,
+ cache_dir: str,
+ proxies: str,
+ local_files_only: bool,
+ token: str,
+ revision: str,
+):
+ model_info_call_error = None
+ if not local_files_only:
+ try:
+ info = model_info(pretrained_model_name, token=token, revision=revision)
+ except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
+ logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
+ local_files_only = True
+ model_info_call_error = e # save error to reraise it if model is not cached locally
+
+ if (
+ not local_files_only
+ and dduf_file is not None
+ and dduf_file not in (sibling.rfilename for sibling in info.siblings)
+ ):
+ raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.")
+
+ try:
+ user_agent = {"pipeline_class": pipeline_class_name, "dduf": True}
+ cached_folder = snapshot_download(
+ pretrained_model_name,
+ cache_dir=cache_dir,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ allow_patterns=[dduf_file],
+ user_agent=user_agent,
+ )
+ return cached_folder
+ except FileNotFoundError:
+ # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache.
+ # This can happen in two cases:
+ # 1. If the user passed `local_files_only=True` => we raise the error directly
+ # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error
+ if model_info_call_error is None:
+ # 1. user passed `local_files_only=True`
+ raise
+ else:
+ # 2. we forced `local_files_only=True` when `model_info` failed
+ raise EnvironmentError(
+ f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred"
+ " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
+ " above."
+ ) from model_info_call_error
+
+
+def _maybe_raise_error_for_incorrect_transformers(config_dict):
+ has_transformers_component = False
+ for k in config_dict:
+ if isinstance(config_dict[k], list):
+ has_transformers_component = config_dict[k][0] == "transformers"
+ if has_transformers_component:
+ break
+ if has_transformers_component and not is_transformers_version(">", "4.47.1"):
+ raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 2e1858b16148..66b56740ef13 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,16 +28,19 @@
import requests
import torch
from huggingface_hub import (
+ DDUFEntry,
ModelCard,
create_repo,
hf_hub_download,
model_info,
+ read_dduf_file,
snapshot_download,
)
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
from packaging import version
from requests.exceptions import HTTPError
from tqdm.auto import tqdm
+from typing_extensions import Self
from .. import __version__
from ..configuration_utils import ConfigMixin
@@ -51,6 +54,8 @@
DEPRECATED_REVISION_ARGS,
BaseOutput,
PushToHubMixin,
+ _get_detailed_type,
+ _is_valid_type,
is_accelerate_available,
is_accelerate_version,
is_torch_npu_available,
@@ -66,12 +71,12 @@
if is_torch_npu_available():
import torch_npu # noqa: F401
-
from .pipeline_loading_utils import (
ALL_IMPORTABLE_CLASSES,
CONNECTED_PIPES_KEYS,
CUSTOM_PIPELINE_FILE_NAME,
LOADABLE_CLASSES,
+ _download_dduf_file,
_fetch_class_library_tuple,
_get_custom_components_and_folders,
_get_custom_pipeline_class,
@@ -79,10 +84,12 @@
_get_ignore_patterns,
_get_pipeline_class,
_identify_model_variants,
+ _maybe_raise_error_for_incorrect_transformers,
_maybe_raise_warning_for_inpainting,
_resolve_custom_pipeline_and_cls,
_unwrap_model,
_update_init_kwargs_with_connected_pipeline,
+ filter_model_files,
load_sub_model,
maybe_raise_or_warn,
variant_compatible_siblings,
@@ -218,6 +225,7 @@ class implements both a save and loading method. The pipeline is easily reloaded
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
+
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -229,7 +237,7 @@ class implements both a save and loading method. The pipeline is easily reloaded
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
- private = kwargs.pop("private", False)
+ private = kwargs.pop("private", None)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -318,7 +326,7 @@ def is_saveable_module(name, value):
create_pr=create_pr,
)
- def to(self, *args, **kwargs):
+ def to(self, *args, **kwargs) -> Self:
r"""
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
arguments of `self.to(*args, **kwargs).`
@@ -388,6 +396,8 @@ def to(self, *args, **kwargs):
)
device = device or device_arg
+ device_type = torch.device(device).type if device is not None else None
+ pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
def module_is_sequentially_offloaded(module):
@@ -410,20 +420,27 @@ def module_is_offloaded(module):
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
- if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
- raise ValueError(
- "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
- )
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
- "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
+ "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
)
+ if device_type in ["cuda", "xpu"]:
+ if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
+ raise ValueError(
+ "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+ )
+ # PR: https://github.com/huggingface/accelerate/pull/3223/
+ elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
+ raise ValueError(
+ "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
+ )
+
# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
- if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
+ if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
logger.warning(
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)
@@ -435,6 +452,7 @@ def module_is_offloaded(module):
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
for module in modules:
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
+ is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module)
if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
logger.warning(
@@ -446,11 +464,21 @@ def module_is_offloaded(module):
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
)
+ # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
+ # components can be from outside diffusers too, but still have group offloading enabled.
+ if (
+ self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module)
+ and device is not None
+ ):
+ logger.warning(
+ f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported."
+ )
+
# This can happen for `transformer` models. CPU placement was added in
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
module.to(device=device)
- elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
+ elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
module.to(device, dtype)
if (
@@ -500,7 +528,7 @@ def dtype(self) -> torch.dtype:
@classmethod
@validate_hf_hub_args
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self:
r"""
Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights.
@@ -509,7 +537,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If you get the error message below, you need to finetune the weights for your downstream task:
```
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
@@ -523,9 +551,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
saved using
[`~DiffusionPipeline.save_pretrained`].
- torch_dtype (`str` or `torch.dtype`, *optional*):
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
+ torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
- dtype is automatically derived from the model's weights.
+ dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
+ `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
+ unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default':
+ torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used.
custom_pipeline (`str`, *optional*):
@@ -617,6 +649,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
variant (`str`, *optional*):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
+ dduf_file(`str`, *optional*):
+ Load weights from the specified dduf file.
@@ -636,7 +670,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
>>> # Download pipeline that requires an authorization token
>>> # For more information on access tokens, please refer to this section
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
- >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> # Use a different scheduler
>>> from diffusers import LMSDiscreteScheduler
@@ -660,16 +694,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
custom_revision = kwargs.pop("custom_revision", None)
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
+ provider_options = kwargs.pop("provider_options", None)
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
- offload_state_dict = kwargs.pop("offload_state_dict", False)
+ offload_state_dict = kwargs.pop("offload_state_dict", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
+ dduf_file = kwargs.pop("dduf_file", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
+ if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
+ torch_dtype = torch.float32
+ logger.warning(
+ f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+ )
+
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
@@ -714,6 +756,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
+ if dduf_file:
+ if custom_pipeline:
+ raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
+ if load_connected_pipeline:
+ raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.")
+
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
@@ -736,6 +784,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
custom_pipeline=custom_pipeline,
custom_revision=custom_revision,
variant=variant,
+ dduf_file=dduf_file,
load_connected_pipeline=load_connected_pipeline,
**kwargs,
)
@@ -757,7 +806,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)
logger.warning(warn_msg)
- config_dict = cls.load_config(cached_folder)
+ dduf_entries = None
+ if dduf_file:
+ dduf_file_path = os.path.join(cached_folder, dduf_file)
+ dduf_entries = read_dduf_file(dduf_file_path)
+ # The reader contains already all the files needed, no need to check it again
+ cached_folder = ""
+
+ config_dict = cls.load_config(cached_folder, dduf_entries=dduf_entries)
+
+ if dduf_file:
+ _maybe_raise_error_for_incorrect_transformers(config_dict)
# pop out "_ignore_files" as it is only needed for download
config_dict.pop("_ignore_files", None)
@@ -805,6 +864,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+ expected_types = pipeline_class._get_signature_types()
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -893,6 +953,11 @@ def load_module(name, value):
loaded_sub_model = passed_class_obj[name]
else:
# load sub model
+ sub_model_dtype = (
+ torch_dtype.get(name, torch_dtype.get("default", torch.float32))
+ if isinstance(torch_dtype, dict)
+ else torch_dtype
+ )
loaded_sub_model = load_sub_model(
library_name=library_name,
class_name=class_name,
@@ -900,7 +965,7 @@ def load_module(name, value):
pipelines=pipelines,
is_pipeline_module=is_pipeline_module,
pipeline_class=pipeline_class,
- torch_dtype=torch_dtype,
+ torch_dtype=sub_model_dtype,
provider=provider,
sess_options=sess_options,
device_map=current_device_map,
@@ -914,6 +979,8 @@ def load_module(name, value):
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
use_safetensors=use_safetensors,
+ dduf_entries=dduf_entries,
+ provider_options=provider_options,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -939,15 +1006,31 @@ def load_module(name, value):
for module in missing_modules:
init_kwargs[module] = passed_class_obj.get(module, None)
elif len(missing_modules) > 0:
- passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
+ passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - set(optional_kwargs)
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
- # 10. Instantiate the pipeline
+ # 10. Type checking init arguments
+ for kw, arg in init_kwargs.items():
+ # Too complex to validate with type annotation alone
+ if "scheduler" in kw:
+ continue
+ # Many tokenizer annotations don't include its "Fast" variant, so skip this
+ # e.g T5Tokenizer but not T5TokenizerFast
+ elif "tokenizer" in kw:
+ continue
+ elif (
+ arg is not None # Skip if None
+ and not expected_types[kw] == (inspect.Signature.empty,) # Skip if no type annotations
+ and not _is_valid_type(arg, expected_types[kw]) # Check type
+ ):
+ logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {_get_detailed_type(arg)}.")
+
+ # 11. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
- # 11. Save where the model was instantiated from
+ # 12. Save where the model was instantiated from
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
if device_map is not None:
setattr(model, "hf_device_map", final_device_map)
@@ -964,6 +1047,19 @@ def _execution_device(self):
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
Accelerate's module hooks.
"""
+ from ..hooks.group_offloading import _get_group_onload_device
+
+ # When apply group offloading at the leaf_level, we're in the same situation as accelerate's sequential
+ # offloading. We need to return the onload device of the group offloading hooks so that the intermediates
+ # required for computation (latents, prompt embeddings, etc.) can be created on the correct device.
+ for name, model in self.components.items():
+ if not isinstance(model, torch.nn.Module):
+ continue
+ try:
+ return _get_group_onload_device(model)
+ except ValueError:
+ pass
+
for name, model in self.components.items():
if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
continue
@@ -1002,6 +1098,8 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
+ self._maybe_raise_error_if_group_offload_active(raise_error=True)
+
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
@@ -1077,11 +1175,20 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
def maybe_free_model_hooks(self):
r"""
- Function that offloads all components, removes all model hooks that were added when using
- `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
- is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
- functions correctly when applying enable_model_cpu_offload.
+ Method that performs the following:
+ - Offloads all components.
+ - Removes all model hooks that were added when using `enable_model_cpu_offload`, and then applies them again.
+ In case the model has not been offloaded, this function is a no-op.
+ - Resets stateful diffusers hooks of denoiser components if they were added with
+ [`~hooks.HookRegistry.register_hook`].
+
+ Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions
+ correctly when applying `enable_model_cpu_offload`.
"""
+ for component in self.components.values():
+ if hasattr(component, "_reset_stateful_cache"):
+ component._reset_stateful_cache()
+
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
# `enable_model_cpu_offload` has not be called, so silently do nothing
return
@@ -1104,6 +1211,8 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
+ self._maybe_raise_error_if_group_offload_active(raise_error=True)
+
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
from accelerate import cpu_offload
else:
@@ -1227,6 +1336,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
variant (`str`, *optional*):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
+ dduf_file(`str`, *optional*):
+ Load weights from the specified DDUF file.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
@@ -1267,11 +1378,26 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
trust_remote_code = kwargs.pop("trust_remote_code", False)
+ dduf_file: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_file", None)
+
+ if dduf_file:
+ if custom_pipeline:
+ raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
+ if load_connected_pipeline:
+ raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.")
+ return _download_dduf_file(
+ pretrained_model_name=pretrained_model_name,
+ dduf_file=dduf_file,
+ pipeline_class_name=cls.__name__,
+ cache_dir=cache_dir,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ )
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
+ allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False
+ use_safetensors = use_safetensors if use_safetensors is not None else True
allow_patterns = None
ignore_patterns = None
@@ -1286,6 +1412,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
model_info_call_error = e # save error to reraise it if model is not cached locally
if not local_files_only:
+ config_file = hf_hub_download(
+ pretrained_model_name,
+ cls.config_name,
+ cache_dir=cache_dir,
+ revision=revision,
+ proxies=proxies,
+ force_download=force_download,
+ token=token,
+ )
+ config_dict = cls._dict_from_json_file(config_file)
+ ignore_filenames = config_dict.pop("_ignore_files", [])
+
filenames = {sibling.rfilename for sibling in info.siblings}
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
warn_msg = (
@@ -1300,60 +1438,20 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)
logger.warning(warn_msg)
- model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
-
- config_file = hf_hub_download(
- pretrained_model_name,
- cls.config_name,
- cache_dir=cache_dir,
- revision=revision,
- proxies=proxies,
- force_download=force_download,
- token=token,
- )
-
- config_dict = cls._dict_from_json_file(config_file)
- ignore_filenames = config_dict.pop("_ignore_files", [])
-
- # remove ignored filenames
- model_filenames = set(model_filenames) - set(ignore_filenames)
- variant_filenames = set(variant_filenames) - set(ignore_filenames)
-
+ filenames = set(filenames) - set(ignore_filenames)
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.22.0"):
- warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
+ warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames)
custom_components, folder_names = _get_custom_components_and_folders(
- pretrained_model_name, config_dict, filenames, variant_filenames, variant
+ pretrained_model_name, config_dict, filenames, variant
)
- model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
-
custom_class_name = None
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
custom_pipeline = config_dict["_class_name"][0]
custom_class_name = config_dict["_class_name"][1]
- # all filenames compatible with variant will be added
- allow_patterns = list(model_filenames)
-
- # allow all patterns from non-model folders
- # this enables downloading schedulers, tokenizers, ...
- allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
- # add custom component files
- allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
- # add custom pipeline file
- allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
- # also allow downloading config.json files with the model
- allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
-
- allow_patterns += [
- SCHEDULER_CONFIG_NAME,
- CONFIG_NAME,
- cls.config_name,
- CUSTOM_PIPELINE_FILE_NAME,
- ]
-
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
load_components_from_hub = len(custom_components) > 0
@@ -1366,8 +1464,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
if load_components_from_hub and not trust_remote_code:
raise ValueError(
- f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
- f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
+ f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
+ f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
@@ -1386,12 +1484,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]
+ # retrieve the names of the folders containing model weights
+ model_folder_names = {
+ os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
+ }
# retrieve all patterns that should not be downloaded and error out when needed
ignore_patterns = _get_ignore_patterns(
passed_components,
model_folder_names,
- model_filenames,
- variant_filenames,
+ filenames,
use_safetensors,
from_flax,
allow_pickle,
@@ -1400,6 +1501,29 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
variant,
)
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+
+ # all filenames compatible with variant will be added
+ allow_patterns = list(model_filenames)
+
+ # allow all patterns from non-model folders
+ # this enables downloading schedulers, tokenizers, ...
+ allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
+ # add custom component files
+ allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
+ # add custom pipeline file
+ allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
+ # also allow downloading config.json files with the model
+ allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
+ allow_patterns += [
+ SCHEDULER_CONFIG_NAME,
+ CONFIG_NAME,
+ cls.config_name,
+ CUSTOM_PIPELINE_FILE_NAME,
+ ]
+
# Don't download any objects that are passed
allow_patterns = [
p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
@@ -1442,7 +1566,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
user_agent=user_agent,
)
- # retrieve pipeline class from local file
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
@@ -1495,7 +1618,7 @@ def _get_signature_keys(cls, obj):
expected_modules.add(name)
optional_parameters.remove(name)
- return expected_modules, optional_parameters
+ return sorted(expected_modules), sorted(optional_parameters)
@classmethod
def _get_signature_types(cls):
@@ -1527,7 +1650,7 @@ def components(self) -> Dict[str, Any]:
... StableDiffusionInpaintPipeline,
... )
- >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> text2img = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
>>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
```
@@ -1537,10 +1660,12 @@ def components(self) -> Dict[str, Any]:
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}
- if set(components.keys()) != expected_modules:
+ actual = sorted(set(components.keys()))
+ expected = sorted(expected_modules)
+ if actual != expected:
raise ValueError(
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
- f" {expected_modules} to be defined, but {components.keys()} are defined."
+ f" {expected} to be defined, but {actual} are defined."
)
return components
@@ -1552,6 +1677,7 @@ def numpy_to_pil(images):
"""
return numpy_to_pil(images)
+ @torch.compiler.disable
def progress_bar(self, iterable=None, total=None):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
@@ -1659,7 +1785,7 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
>>> from diffusers import StableDiffusionPipeline
>>> pipe = StableDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5",
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
... torch_dtype=torch.float16,
... use_safetensors=True,
... )
@@ -1706,13 +1832,13 @@ def from_pipe(cls, pipeline, **kwargs):
```py
>>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline
- >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe)
```
"""
original_config = dict(pipeline.config)
- torch_dtype = kwargs.pop("torch_dtype", None)
+ torch_dtype = kwargs.pop("torch_dtype", torch.float32)
# derive the pipeline class to instantiate
custom_pipeline = kwargs.pop("custom_pipeline", None)
@@ -1810,6 +1936,24 @@ def from_pipe(cls, pipeline, **kwargs):
return new_pipeline
+ def _maybe_raise_error_if_group_offload_active(
+ self, raise_error: bool = False, module: Optional[torch.nn.Module] = None
+ ) -> bool:
+ from ..hooks.group_offloading import _is_group_offload_enabled
+
+ components = self.components.values() if module is None else [module]
+ components = [component for component in components if isinstance(component, torch.nn.Module)]
+ for component in components:
+ if _is_group_offload_enabled(component):
+ if raise_error:
+ raise ValueError(
+ "You are trying to apply model/sequential CPU offloading to a pipeline that contains components "
+ "with group offloading enabled. This is not supported. Please disable group offloading for "
+ "components of the pipeline to use other offloading methods."
+ )
+ return True
+ return False
+
class StableDiffusionMixin:
r"""
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 46d8ad5e6dfa..988e049dd684 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -29,6 +29,7 @@
deprecate,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -36,8 +37,16 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -285,7 +294,7 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
@@ -338,13 +347,6 @@ def encode_prompt(
if device is None:
device = self._execution_device
- if prompt is not None and isinstance(prompt, str):
- batch_size = 1
- elif prompt is not None and isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- batch_size = prompt_embeds.shape[0]
-
# See Section 3.1. of the paper.
max_length = max_sequence_length
@@ -389,12 +391,12 @@ def encode_prompt(
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
- prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
- uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
@@ -421,10 +423,10 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
@@ -905,10 +907,11 @@ def __call__(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
+ is_npu = latent_model_input.device.type == "npu"
if isinstance(current_timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
@@ -938,8 +941,7 @@ def __call__(
# compute previous image: x_t -> x_t-1
if num_inference_steps == 1:
- # For DMD one step sampling: https://arxiv.org/abs/2311.18828
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[1]
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
@@ -950,6 +952,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index b2772d552514..7f10ee89ee04 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -29,6 +29,7 @@
deprecate,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -41,8 +42,16 @@
)
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -211,7 +220,7 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300
@@ -264,13 +273,6 @@ def encode_prompt(
if device is None:
device = self._execution_device
- if prompt is not None and isinstance(prompt, str):
- batch_size = 1
- elif prompt is not None and isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- batch_size = prompt_embeds.shape[0]
-
# See Section 3.1. of the paper.
max_length = max_sequence_length
@@ -315,12 +317,12 @@ def encode_prompt(
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
- prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
- uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
@@ -347,10 +349,10 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
@@ -820,10 +822,11 @@ def __call__(
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
+ is_npu = latent_model_input.device.type == "npu"
if isinstance(current_timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
else:
- dtype = torch.int32 if is_mps else torch.int64
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
@@ -861,6 +864,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py
new file mode 100644
index 000000000000..1393b37e2d3a
--- /dev/null
+++ b/src/diffusers/pipelines/sana/__init__.py
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_sana"] = ["SanaPipeline"]
+ _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_sana import SanaPipeline
+ from .pipeline_sana_sprint import SanaSprintPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/sana/pipeline_output.py b/src/diffusers/pipelines/sana/pipeline_output.py
new file mode 100644
index 000000000000..f8ac12951644
--- /dev/null
+++ b/src/diffusers/pipelines/sana/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class SanaPipelineOutput(BaseOutput):
+ """
+ Output class for Sana pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
new file mode 100644
index 000000000000..6093fd836aad
--- /dev/null
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -0,0 +1,1009 @@
+# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PixArtImageProcessor
+from ...loaders import SanaLoraLoaderMixin
+from ...models import AutoencoderDC, SanaTransformer2DModel
+from ...schedulers import DPMSolverMultistepScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ USE_PEFT_BACKEND,
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from ..pixart_alpha.pipeline_pixart_alpha import (
+ ASPECT_RATIO_512_BIN,
+ ASPECT_RATIO_1024_BIN,
+)
+from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
+from .pipeline_output import SanaPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+ASPECT_RATIO_4096_BIN = {
+ "0.25": [2048.0, 8192.0],
+ "0.26": [2048.0, 7936.0],
+ "0.27": [2048.0, 7680.0],
+ "0.28": [2048.0, 7424.0],
+ "0.32": [2304.0, 7168.0],
+ "0.33": [2304.0, 6912.0],
+ "0.35": [2304.0, 6656.0],
+ "0.4": [2560.0, 6400.0],
+ "0.42": [2560.0, 6144.0],
+ "0.48": [2816.0, 5888.0],
+ "0.5": [2816.0, 5632.0],
+ "0.52": [2816.0, 5376.0],
+ "0.57": [3072.0, 5376.0],
+ "0.6": [3072.0, 5120.0],
+ "0.68": [3328.0, 4864.0],
+ "0.72": [3328.0, 4608.0],
+ "0.78": [3584.0, 4608.0],
+ "0.82": [3584.0, 4352.0],
+ "0.88": [3840.0, 4352.0],
+ "0.94": [3840.0, 4096.0],
+ "1.0": [4096.0, 4096.0],
+ "1.07": [4096.0, 3840.0],
+ "1.13": [4352.0, 3840.0],
+ "1.21": [4352.0, 3584.0],
+ "1.29": [4608.0, 3584.0],
+ "1.38": [4608.0, 3328.0],
+ "1.46": [4864.0, 3328.0],
+ "1.67": [5120.0, 3072.0],
+ "1.75": [5376.0, 3072.0],
+ "2.0": [5632.0, 2816.0],
+ "2.09": [5888.0, 2816.0],
+ "2.4": [6144.0, 2560.0],
+ "2.5": [6400.0, 2560.0],
+ "2.89": [6656.0, 2304.0],
+ "3.0": [6912.0, 2304.0],
+ "3.11": [7168.0, 2304.0],
+ "3.62": [7424.0, 2048.0],
+ "3.75": [7680.0, 2048.0],
+ "3.88": [7936.0, 2048.0],
+ "4.0": [8192.0, 2048.0],
+}
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import SanaPipeline
+
+ >>> pipe = SanaPipeline.from_pretrained(
+ ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
+ ... )
+ >>> pipe.to("cuda")
+ >>> pipe.text_encoder.to(torch.bfloat16)
+ >>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
+
+ >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
+ >>> image[0].save("output.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629).
+ """
+
+ # fmt: off
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
+ # fmt: on
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ text_encoder: Gemma2PreTrainedModel,
+ vae: AutoencoderDC,
+ transformer: SanaTransformer2DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
+ if hasattr(self, "vae") and self.vae is not None
+ else 32
+ )
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ dtype: torch.dtype,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+ # prepare complex human instruction
+ if not complex_human_instruction:
+ max_length_all = max_sequence_length
+ else:
+ chi_prompt = "\n".join(complex_human_instruction)
+ prompt = [chi_prompt + p for p in prompt]
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length_all,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+
+ if device is None:
+ device = self._execution_device
+
+ if self.transformer is not None:
+ dtype = self.transformer.dtype
+ elif self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = None
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+ select_index = [0] + list(range(-max_length + 1, 0))
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ )
+
+ prompt_embeds = prompt_embeds[:, select_index]
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=False,
+ )
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ if self.text_encoder is not None:
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 20,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ height: int = 1024,
+ width: int = 1024,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clean_caption: bool = False,
+ use_resolution_binning: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 300,
+ complex_human_instruction: List[str] = [
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+ "Here are examples of how to transform or refine prompts:",
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+ "User Prompt: ",
+ ],
+ ) -> Union[SanaPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ attention_kwargs:
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool` defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+ the requested resolution. Useful for generating non-square images.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `300`):
+ Maximum sequence length to use with the `prompt`.
+ complex_human_instruction (`List[str]`, *optional*):
+ Instructions for complex human attention:
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 128:
+ aspect_ratio_bin = ASPECT_RATIO_4096_BIN
+ elif self.transformer.config.sample_size == 64:
+ aspect_ratio_bin = ASPECT_RATIO_2048_BIN
+ elif self.transformer.config.sample_size == 32:
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+ elif self.transformer.config.sample_size == 16:
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ lora_scale=lora_scale,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+ timestep = timestep * self.transformer.config.timestep_scale
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timestep,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+ else:
+ noise_pred = noise_pred
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ try:
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ except torch.cuda.OutOfMemoryError as e:
+ warnings.warn(
+ f"{e}. \n"
+ f"Try to use VAE tiling for large images. For example: \n"
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+ )
+ if use_resolution_binning:
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return SanaPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
new file mode 100644
index 000000000000..9b3acbb1cb22
--- /dev/null
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
@@ -0,0 +1,889 @@
+# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PixArtImageProcessor
+from ...loaders import SanaLoraLoaderMixin
+from ...models import AutoencoderDC, SanaTransformer2DModel
+from ...schedulers import DPMSolverMultistepScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ USE_PEFT_BACKEND,
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
+from .pipeline_output import SanaPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import SanaSprintPipeline
+
+ >>> pipe = SanaSprintPipeline.from_pretrained(
+ ... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0]
+ >>> image[0].save("output.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641).
+ """
+
+ # fmt: off
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
+ # fmt: on
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ text_encoder: Gemma2PreTrainedModel,
+ vae: AutoencoderDC,
+ transformer: SanaTransformer2DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
+ if hasattr(self, "vae") and self.vae is not None
+ else 32
+ )
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ dtype: torch.dtype,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+ # prepare complex human instruction
+ if not complex_human_instruction:
+ max_length_all = max_sequence_length
+ else:
+ chi_prompt = "\n".join(complex_human_instruction)
+ prompt = [chi_prompt + p for p in prompt]
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length_all,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+
+ if device is None:
+ device = self._execution_device
+
+ if self.transformer is not None:
+ dtype = self.transformer.dtype
+ elif self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = None
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+ select_index = [0] + list(range(-max_length + 1, 0))
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ )
+
+ prompt_embeds = prompt_embeds[:, select_index]
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ if self.text_encoder is not None:
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ num_inference_steps,
+ timesteps,
+ max_timesteps,
+ intermediate_timesteps,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if timesteps is not None and len(timesteps) != num_inference_steps + 1:
+ raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
+
+ if timesteps is not None and max_timesteps is not None:
+ raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
+
+ if timesteps is None and max_timesteps is None:
+ raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
+
+ if intermediate_timesteps is not None and num_inference_steps != 2:
+ raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_inference_steps: int = 2,
+ timesteps: List[int] = None,
+ max_timesteps: float = 1.57080,
+ intermediate_timesteps: float = 1.3,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ height: int = 1024,
+ width: int = 1024,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clean_caption: bool = False,
+ use_resolution_binning: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 300,
+ complex_human_instruction: List[str] = [
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+ "Here are examples of how to transform or refine prompts:",
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+ "User Prompt: ",
+ ],
+ ) -> Union[SanaPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ max_timesteps (`float`, *optional*, defaults to 1.57080):
+ The maximum timestep value used in the SCM scheduler.
+ intermediate_timesteps (`float`, *optional*, defaults to 1.3):
+ The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ attention_kwargs:
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool` defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+ the requested resolution. Useful for generating non-square images.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `300`):
+ Maximum sequence length to use with the `prompt`.
+ complex_human_instruction (`List[str]`, *optional*):
+ Instructions for complex human attention:
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 32:
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ max_timesteps=max_timesteps,
+ intermediate_timesteps=intermediate_timesteps,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=None,
+ max_timesteps=max_timesteps,
+ intermediate_timesteps=intermediate_timesteps,
+ )
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(0)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ latents = latents * self.scheduler.config.sigma_data
+
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
+ guidance = guidance * self.transformer.config.guidance_embeds_scale
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ timesteps = timesteps[:-1]
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
+ latents_model_input = latents / self.scheduler.config.sigma_data
+
+ scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
+
+ scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1)
+ latent_model_input = latents_model_input * torch.sqrt(
+ scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
+ )
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ guidance=guidance,
+ timestep=scm_timestep,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+
+ noise_pred = (
+ (1 - 2 * scm_timestep_expanded) * latent_model_input
+ + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred
+ ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2)
+ noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
+
+ # compute previous image: x_t -> x_t-1
+ latents, denoised = self.scheduler.step(
+ noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False
+ )
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ latents = denoised / self.scheduler.config.sigma_data
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ try:
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ except torch.cuda.OutOfMemoryError as e:
+ warnings.warn(
+ f"{e}. \n"
+ f"Try to use VAE tiling for large images. For example: \n"
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+ )
+ if use_resolution_binning:
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return SanaPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
index 6f83071f3e85..a8c374259349 100644
--- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
+++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
@@ -9,12 +9,19 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import SemanticStableDiffusionPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -87,7 +94,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -701,6 +708,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post-processing
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
index f87f28e06c4a..ef8a95daefa4 100644
--- a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
+++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
@@ -25,6 +25,7 @@
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -33,8 +34,16 @@
from .renderer import ShapERenderer
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -291,6 +300,9 @@ def __call__(
sample=latents,
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# Offload all models
self.maybe_free_model_hooks()
diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
index 7cc145e4c3e2..c0d1e38e0994 100644
--- a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
+++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
@@ -24,6 +24,7 @@
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -32,8 +33,16 @@
from .renderer import ShapERenderer
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -278,6 +287,9 @@ def __call__(
sample=latents,
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type not in ["np", "pil", "latent", "mesh"]:
raise ValueError(
f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
index 4fe082d88957..5d773b614a5c 100644
--- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
+++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
@@ -26,6 +26,7 @@
from ...models.embeddings import get_1d_rotary_pos_embed
from ...schedulers import EDMDPMSolverMultistepScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -34,6 +35,13 @@
from .modeling_stable_audio import StableAudioProjectionModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
@@ -438,7 +446,7 @@ def prepare_latents(
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
)
- audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
+ audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
# check num_channels
@@ -726,6 +734,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post-processing
if not output_type == "latent":
audio = self.vae.decode(latents).sample
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
index 111ccc40c5a5..38f1c4314e4f 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
@@ -15,18 +15,26 @@
from typing import Callable, Dict, List, Optional, Union
import torch
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import is_torch_version, logging, replace_example_docstring
+from ...utils import is_torch_version, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -57,7 +65,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
Args:
tokenizer (`CLIPTokenizer`):
The CLIP tokenizer.
- text_encoder (`CLIPTextModel`):
+ text_encoder (`CLIPTextModelWithProjection`):
The CLIP text encoder.
decoder ([`StableCascadeUNet`]):
The Stable Cascade decoder unet.
@@ -85,7 +93,7 @@ def __init__(
self,
decoder: StableCascadeUNet,
tokenizer: CLIPTokenizer,
- text_encoder: CLIPTextModel,
+ text_encoder: CLIPTextModelWithProjection,
scheduler: DDPMWuerstchenScheduler,
vqgan: PaellaVQModel,
latent_dim_scale: float = 10.67,
@@ -503,6 +511,9 @@ def __call__(
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
index 6724b60cc424..28a74ab83733 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
@@ -15,7 +15,7 @@
import PIL
import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
@@ -52,7 +52,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
Args:
tokenizer (`CLIPTokenizer`):
The decoder tokenizer to be used for text inputs.
- text_encoder (`CLIPTextModel`):
+ text_encoder (`CLIPTextModelWithProjection`):
The decoder text encoder to be used for text inputs.
decoder (`StableCascadeUNet`):
The decoder model to be used for decoder image generation pipeline.
@@ -60,14 +60,18 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
The scheduler to be used for decoder image generation pipeline.
vqgan (`PaellaVQModel`):
The VQGAN model to be used for decoder image generation pipeline.
- feature_extractor ([`~transformers.CLIPImageProcessor`]):
- Model that extracts features from generated images to be used as inputs for the `image_encoder`.
- image_encoder ([`CLIPVisionModelWithProjection`]):
- Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
prior_prior (`StableCascadeUNet`):
The prior model to be used for prior pipeline.
+ prior_text_encoder (`CLIPTextModelWithProjection`):
+ The prior text encoder to be used for text inputs.
+ prior_tokenizer (`CLIPTokenizer`):
+ The prior tokenizer to be used for text inputs.
prior_scheduler (`DDPMWuerstchenScheduler`):
The scheduler to be used for prior pipeline.
+ prior_feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `image_encoder`.
+ prior_image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
"""
_load_connected_pipes = True
@@ -76,12 +80,12 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
def __init__(
self,
tokenizer: CLIPTokenizer,
- text_encoder: CLIPTextModel,
+ text_encoder: CLIPTextModelWithProjection,
decoder: StableCascadeUNet,
scheduler: DDPMWuerstchenScheduler,
vqgan: PaellaVQModel,
prior_prior: StableCascadeUNet,
- prior_text_encoder: CLIPTextModel,
+ prior_text_encoder: CLIPTextModelWithProjection,
prior_tokenizer: CLIPTokenizer,
prior_scheduler: DDPMWuerstchenScheduler,
prior_feature_extractor: Optional[CLIPImageProcessor] = None,
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
index 058dbf6b0797..241c454e103e 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
@@ -23,13 +23,21 @@
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import BaseOutput, logging, replace_example_docstring
+from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
EXAMPLE_DOC_STRING = """
@@ -611,6 +619,9 @@ def __call__(
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# Offload all models
self.maybe_free_model_hooks()
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 53dc98aea698..4cc4eabd4a40 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 5d6ffd463cc3..eaeb5f809c47 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -55,7 +55,7 @@
>>> from diffusers import FlaxStableDiffusionPipeline
>>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jax.numpy.bfloat16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="bf16", dtype=jax.numpy.bfloat16
... )
>>> prompt = "a photo of an astronaut riding a horse on mars"
@@ -100,8 +100,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -132,17 +132,21 @@ def __init__(
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -162,7 +166,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
index 7792bc097595..c2d918156084 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
@@ -124,8 +124,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -165,7 +165,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
if not isinstance(prompt, (str, list)):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
index f6bb0ac299b3..abcba926160a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
@@ -127,8 +127,8 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -159,17 +159,21 @@ def __init__(
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -189,7 +193,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def prepare_inputs(
self,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
index 2e34dcb83c01..9917276e0a1f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -57,7 +57,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -71,7 +71,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
index c39409886bd9..92c82d61b8f2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -78,7 +78,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -109,7 +110,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -123,7 +124,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index 18d8050826cc..ddd2e27dedaf 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -76,7 +76,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -108,7 +109,7 @@ def __init__(
super().__init__()
logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -122,7 +123,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
index cd9ec57fb879..ef84cdd38b6d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -83,7 +83,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -97,7 +97,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 4fd6a43a955a..6e93c34929de 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -55,7 +55,9 @@
>>> import torch
>>> from diffusers import StableDiffusionPipeline
- >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionPipeline.from_pretrained(
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ... )
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
@@ -184,8 +186,8 @@ class StableDiffusionPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -209,7 +211,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -223,7 +225,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -252,17 +254,25 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ self._is_unet_config_sample_size_int = unet is not None and isinstance(unet.config.sample_size, int)
+ is_unet_sample_size_less_64 = (
+ unet is not None
+ and hasattr(unet.config, "sample_size")
+ and self._is_unet_config_sample_size_int
+ and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -283,7 +293,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -902,8 +912,18 @@ def __call__(
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default height and width to unet
- height = height or self.unet.config.sample_size * self.vae_scale_factor
- width = width or self.unet.config.sample_size * self.vae_scale_factor
+ if not height or not width:
+ height = (
+ self.unet.config.sample_size
+ if self._is_unet_config_sample_size_int
+ else self.unet.config.sample_size[0]
+ )
+ width = (
+ self.unet.config.sample_size
+ if self._is_unet_config_sample_size_int
+ else self.unet.config.sample_size[1]
+ )
+ height, width = height * self.vae_scale_factor, width * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index 7801b0d01dff..f158c41cac53 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -28,11 +28,26 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ PIL_INTERPOLATION,
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -115,17 +130,21 @@ def __init__(
):
super().__init__()
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -145,7 +164,7 @@ def __init__(
depth_estimator=depth_estimator,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -861,6 +880,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index 93a8bd160318..e0268065a415 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -24,13 +24,20 @@
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -57,8 +64,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMi
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -97,17 +104,21 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -126,7 +137,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -401,6 +412,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
self.maybe_free_model_hooks()
if not output_type == "latent":
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 9cd5673c9359..901dcd6db012 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -32,6 +32,7 @@
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -43,8 +44,16 @@
from .safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -56,7 +65,7 @@
>>> from diffusers import StableDiffusionImg2ImgPipeline
>>> device = "cuda"
- >>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
+ >>> model_id_or_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
>>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
>>> pipe = pipe.to(device)
@@ -205,8 +214,8 @@ class StableDiffusionImg2ImgPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -230,7 +239,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -244,7 +253,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -273,17 +282,21 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -304,7 +317,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1120,6 +1133,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 49c38c800942..6f4e7f358952 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -27,13 +27,27 @@
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -146,8 +160,8 @@ class StableDiffusionInpaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -171,7 +185,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -185,7 +199,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -215,17 +229,21 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -237,7 +255,7 @@ def __init__(
unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
- if unet.config.in_channels != 9:
+ if unet is not None and unet.config.in_channels != 9:
logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
self.register_modules(
@@ -250,7 +268,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -1014,7 +1032,7 @@ def __call__(
>>> mask_image = download_image(mask_url).resize((512, 512))
>>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
@@ -1200,7 +1218,7 @@ def __call__(
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
@@ -1303,6 +1321,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
condition_kwargs = {}
if isinstance(self.vae, AsymmetricAutoencoderKL):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index fd89b195c778..7857bc58a8ad 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -22,16 +22,23 @@
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -79,6 +86,7 @@ class StableDiffusionInstructPix2PixPipeline(
TextualInversionLoaderMixin,
StableDiffusionLoraLoaderMixin,
IPAdapterMixin,
+ FromSingleFileMixin,
):
r"""
Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
@@ -106,8 +114,8 @@ class StableDiffusionInstructPix2PixPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -157,7 +165,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -457,6 +465,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index ffe02ae679e5..c6967bc393b5 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -25,11 +25,18 @@
from ...loaders import FromSingleFileMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -116,7 +123,7 @@ def __init__(
unet=unet,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
def _encode_prompt(
@@ -640,6 +647,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 4cbbe17531ef..dae4540ebe00 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -30,12 +30,26 @@
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -149,7 +163,7 @@ def __init__(
watermarker=watermarker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
self.register_to_config(max_noise_level=max_noise_level)
@@ -769,6 +783,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index 41811f8f2c0e..be01e0acbf18 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -28,6 +28,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -38,8 +39,16 @@
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -132,7 +141,7 @@ def __init__(
image_noising_scheduler: KarrasDiffusionSchedulers,
# regular denoising components
tokenizer: CLIPTokenizer,
- text_encoder: CLIPTextModelWithProjection,
+ text_encoder: CLIPTextModel,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
# vae
@@ -154,7 +163,7 @@ def __init__(
vae=vae,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder
@@ -924,6 +933,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index 2556d5e57b6d..eac9945ff349 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -28,6 +28,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -38,8 +39,16 @@
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -155,7 +164,7 @@ def __init__(
vae=vae,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -829,6 +838,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post-processing
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 4b9df578bc4a..4618d384cbd7 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,12 +19,14 @@
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
+ SiglipImageProcessor,
+ SiglipVisionModel,
T5EncoderModel,
T5TokenizerFast,
)
-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -68,6 +70,20 @@
"""
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -128,7 +144,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
@@ -160,10 +176,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ image_encoder (`SiglipVisionModel`, *optional*):
+ Pre-trained Vision Model for IP Adapter.
+ feature_extractor (`SiglipImageProcessor`, *optional*):
+ Image processor for IP Adapter.
"""
- model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
- _optional_components = []
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(
@@ -177,6 +197,8 @@ def __init__(
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
+ image_encoder: SiglipVisionModel = None,
+ feature_extractor: SiglipImageProcessor = None,
):
super().__init__()
@@ -190,10 +212,10 @@ def __init__(
tokenizer_3=tokenizer_3,
transformer=transformer,
scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -361,9 +383,9 @@ def encode_prompt(
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
- negative_prompt_2 (`str` or `List[str]`, *optional*):
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
- `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -642,6 +664,10 @@ def prepare_latents(
def guidance_scale(self):
return self._guidance_scale
+ @property
+ def skip_guidance_layers(self):
+ return self._skip_guidance_layers
+
@property
def clip_skip(self):
return self._clip_skip
@@ -665,6 +691,83 @@ def num_timesteps(self):
def interrupt(self):
return self._interrupt
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+ Args:
+ image (`PipelineImageInput`):
+ Input image to be encoded.
+ device: (`torch.device`):
+ Torch device.
+
+ Returns:
+ `torch.Tensor`: The encoded image feature representation.
+ """
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=self.dtype)
+
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ ) -> torch.Tensor:
+ """Prepares image embeddings for use in the IP-Adapter.
+
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+ Args:
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ The input image to extract features from for IP-Adapter.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Precomputed image embeddings.
+ device: (`torch.device`, *optional*):
+ Torch device.
+ num_images_per_prompt (`int`, defaults to 1):
+ Number of images that should be generated per prompt.
+ do_classifier_free_guidance (`bool`, defaults to True):
+ Whether to use classifier free guidance or not.
+ """
+ device = device or self._execution_device
+
+ if ip_adapter_image_embeds is not None:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+ else:
+ single_image_embeds = ip_adapter_image_embeds
+ elif ip_adapter_image is not None:
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+ else:
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+ return image_embeds.to(device=device)
+
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+ logger.warning(
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+ )
+
+ super().enable_sequential_cpu_offload(*args, **kwargs)
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -675,7 +778,7 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -687,6 +790,8 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -694,6 +799,11 @@ def __call__(
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
+ skip_guidance_layers: List[int] = None,
+ skip_layer_guidance_scale: float = 2.8,
+ skip_layer_guidance_stop: float = 0.2,
+ skip_layer_guidance_start: float = 0.01,
+ mu: Optional[float] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -715,10 +825,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -758,12 +868,18 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
- of a plain tuple.
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
+ a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -778,6 +894,23 @@ def __call__(
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+ skip_guidance_layers (`List[int]`, *optional*):
+ A list of integers that specify layers to skip during guidance. If not provided, all layers will be
+ used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
+ Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
+ skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in
+ `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
+ with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
+ with a scale of `1`.
+ skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in
+ `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
+ `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
+ StabiltyAI for Stable Diffusion 3.5 Medium is 0.2.
+ skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in
+ `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
+ `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
+ StabiltyAI for Stable Diffusion 3.5 Medium is 0.01.
+ mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
Examples:
@@ -809,6 +942,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
@@ -851,15 +985,13 @@ def __call__(
)
if self.do_classifier_free_guidance:
+ if skip_guidance_layers is not None:
+ original_prompt_embeds = prompt_embeds
+ original_pooled_prompt_embeds = pooled_prompt_embeds
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
- # 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
- self._num_timesteps = len(timesteps)
-
- # 5. Prepare latent variables
+ # 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
@@ -872,7 +1004,49 @@ def __call__(
latents,
)
- # 6. Denoising loop
+ # 5. Prepare timesteps
+ scheduler_kwargs = {}
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
+ _, _, height, width = latents.shape
+ image_seq_len = (height // self.transformer.config.patch_size) * (
+ width // self.transformer.config.patch_size
+ )
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.16),
+ )
+ scheduler_kwargs["mu"] = mu
+ elif mu is not None:
+ scheduler_kwargs["mu"] = mu
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare image embeddings
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+ else:
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+ # 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
@@ -896,6 +1070,27 @@ def __call__(
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ should_skip_layers = (
+ True
+ if i > num_inference_steps * skip_layer_guidance_start
+ and i < num_inference_steps * skip_layer_guidance_stop
+ else False
+ )
+ if skip_guidance_layers is not None and should_skip_layers:
+ timestep = t.expand(latents.shape[0])
+ latent_model_input = latents
+ noise_pred_skip_layers = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=original_prompt_embeds,
+ pooled_projections=original_pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ skip_layers=skip_guidance_layers,
+ )[0]
+ noise_pred = (
+ noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
+ )
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index 794716303394..19bdc9792e23 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -20,12 +20,14 @@
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
+ SiglipImageProcessor,
+ SiglipVisionModel,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -75,6 +77,20 @@
"""
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -149,7 +165,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
@@ -181,10 +197,14 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ image_encoder (`SiglipVisionModel`, *optional*):
+ Pre-trained Vision Model for IP Adapter.
+ feature_extractor (`SiglipImageProcessor`, *optional*):
+ Image processor for IP Adapter.
"""
- model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
- _optional_components = []
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(
@@ -198,6 +218,8 @@ def __init__(
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
+ image_encoder: Optional[SiglipVisionModel] = None,
+ feature_extractor: Optional[SiglipImageProcessor] = None,
):
super().__init__()
@@ -211,13 +233,25 @@ def __init__(
tokenizer_3=tokenizer_3,
transformer=transformer,
scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(
- vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
+ vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
+ )
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = (
+ self.transformer.config.sample_size
+ if hasattr(self, "transformer") and self.transformer is not None
+ else 128
+ )
+ self.patch_size = (
+ self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
)
- self.tokenizer_max_length = self.tokenizer.model_max_length
- self.default_sample_size = self.transformer.config.sample_size
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
@@ -376,9 +410,9 @@ def encode_prompt(
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
- negative_prompt_2 (`str` or `List[str]`, *optional*):
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
- `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -531,6 +565,8 @@ def check_inputs(
prompt,
prompt_2,
prompt_3,
+ height,
+ width,
strength,
negative_prompt=None,
negative_prompt_2=None,
@@ -542,6 +578,15 @@ def check_inputs(
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
+ if (
+ height % (self.vae_scale_factor * self.patch_size) != 0
+ or width % (self.vae_scale_factor * self.patch_size) != 0
+ ):
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
+ )
+
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -703,6 +748,84 @@ def num_timesteps(self):
def interrupt(self):
return self._interrupt
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+ Args:
+ image (`PipelineImageInput`):
+ Input image to be encoded.
+ device: (`torch.device`):
+ Torch device.
+
+ Returns:
+ `torch.Tensor`: The encoded image feature representation.
+ """
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=self.dtype)
+
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ ) -> torch.Tensor:
+ """Prepares image embeddings for use in the IP-Adapter.
+
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+ Args:
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ The input image to extract features from for IP-Adapter.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Precomputed image embeddings.
+ device: (`torch.device`, *optional*):
+ Torch device.
+ num_images_per_prompt (`int`, defaults to 1):
+ Number of images that should be generated per prompt.
+ do_classifier_free_guidance (`bool`, defaults to True):
+ Whether to use classifier free guidance or not.
+ """
+ device = device or self._execution_device
+
+ if ip_adapter_image_embeds is not None:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+ else:
+ single_image_embeds = ip_adapter_image_embeds
+ elif ip_adapter_image is not None:
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+ else:
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+ return image_embeds.to(device=device)
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+ logger.warning(
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+ )
+
+ super().enable_sequential_cpu_offload(*args, **kwargs)
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -710,10 +833,12 @@ def __call__(
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
image: PipelineImageInput = None,
strength: float = 0.6,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -726,12 +851,15 @@ def __call__(
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
+ mu: Optional[float] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -746,17 +874,17 @@ def __call__(
prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
will be used instead
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -796,12 +924,18 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
- of a plain tuple.
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
+ a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -816,6 +950,7 @@ def __call__(
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+ mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
Examples:
@@ -824,12 +959,16 @@ def __call__(
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
prompt_3,
+ height,
+ width,
strength,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
@@ -890,10 +1029,27 @@ def __call__(
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 3. Preprocess image
- image = self.image_processor.preprocess(image)
+ image = self.image_processor.preprocess(image, height=height, width=width)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ scheduler_kwargs = {}
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
+ image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
+ int(width) // self.vae_scale_factor // self.transformer.config.patch_size
+ )
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.16),
+ )
+ scheduler_kwargs["mu"] = mu
+ elif mu is not None:
+ scheduler_kwargs["mu"] = mu
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
+ )
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
@@ -909,7 +1065,22 @@ def __call__(
generator,
)
- # 6. Denoising loop
+ # 6. Prepare image embeddings
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+ else:
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+ # 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index 7401be39d6f9..c69fb90a4c5e 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -13,19 +13,21 @@
# limitations under the License.
import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import (
CLIPTextModelWithProjection,
CLIPTokenizer,
+ SiglipImageProcessor,
+ SiglipVisionModel,
T5EncoderModel,
T5TokenizerFast,
)
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -74,6 +76,20 @@
"""
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -148,7 +164,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
@@ -180,10 +196,14 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ image_encoder (`SiglipVisionModel`, *optional*):
+ Pre-trained Vision Model for IP Adapter.
+ feature_extractor (`SiglipImageProcessor`, *optional*):
+ Image processor for IP Adapter.
"""
- model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
- _optional_components = []
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(
@@ -197,6 +217,8 @@ def __init__(
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
+ image_encoder: Optional[SiglipVisionModel] = None,
+ feature_extractor: Optional[SiglipImageProcessor] = None,
):
super().__init__()
@@ -210,20 +232,32 @@ def __init__(
tokenizer_3=tokenizer_3,
transformer=transformer,
scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(
- vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
+ vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
- vae_latent_channels=self.vae.config.latent_channels,
+ vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
)
- self.tokenizer_max_length = self.tokenizer.model_max_length
- self.default_sample_size = self.transformer.config.sample_size
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = (
+ self.transformer.config.sample_size
+ if hasattr(self, "transformer") and self.transformer is not None
+ else 128
+ )
+ self.patch_size = (
+ self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
+ )
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
@@ -382,9 +416,9 @@ def encode_prompt(
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
- negative_prompt_2 (`str` or `List[str]`, *optional*):
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
- `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -538,6 +572,8 @@ def check_inputs(
prompt,
prompt_2,
prompt_3,
+ height,
+ width,
strength,
negative_prompt=None,
negative_prompt_2=None,
@@ -549,6 +585,15 @@ def check_inputs(
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
+ if (
+ height % (self.vae_scale_factor * self.patch_size) != 0
+ or width % (self.vae_scale_factor * self.patch_size) != 0
+ ):
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
+ )
+
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -783,6 +828,10 @@ def clip_skip(self):
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
@property
def num_timesteps(self):
return self._num_timesteps
@@ -791,6 +840,84 @@ def num_timesteps(self):
def interrupt(self):
return self._interrupt
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+ Args:
+ image (`PipelineImageInput`):
+ Input image to be encoded.
+ device: (`torch.device`):
+ Torch device.
+
+ Returns:
+ `torch.Tensor`: The encoded image feature representation.
+ """
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=self.dtype)
+
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ ) -> torch.Tensor:
+ """Prepares image embeddings for use in the IP-Adapter.
+
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+ Args:
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ The input image to extract features from for IP-Adapter.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Precomputed image embeddings.
+ device: (`torch.device`, *optional*):
+ Torch device.
+ num_images_per_prompt (`int`, defaults to 1):
+ Number of images that should be generated per prompt.
+ do_classifier_free_guidance (`bool`, defaults to True):
+ Whether to use classifier free guidance or not.
+ """
+ device = device or self._execution_device
+
+ if ip_adapter_image_embeds is not None:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+ else:
+ single_image_embeds = ip_adapter_image_embeds
+ elif ip_adapter_image is not None:
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+ else:
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+ return image_embeds.to(device=device)
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+ logger.warning(
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+ )
+
+ super().enable_sequential_cpu_offload(*args, **kwargs)
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -806,7 +933,7 @@ def __call__(
padding_mask_crop: Optional[int] = None,
strength: float = 0.6,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -818,12 +945,16 @@ def __call__(
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
+ mu: Optional[float] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -854,9 +985,9 @@ def __call__(
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
latents tensor will ge generated by `mask_image`.
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
padding_mask_crop (`int`, *optional*, defaults to `None`):
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
@@ -874,10 +1005,10 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -917,12 +1048,22 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
- of a plain tuple.
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
+ a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -933,6 +1074,7 @@ def __call__(
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+ mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
Examples:
@@ -953,6 +1095,8 @@ def __call__(
prompt,
prompt_2,
prompt_3,
+ height,
+ width,
strength,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
@@ -967,6 +1111,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
@@ -1007,7 +1152,24 @@ def __call__(
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 3. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ scheduler_kwargs = {}
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
+ image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
+ int(width) // self.vae_scale_factor // self.transformer.config.patch_size
+ )
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.16),
+ )
+ scheduler_kwargs["mu"] = mu
+ elif mu is not None:
+ scheduler_kwargs["mu"] = mu
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
+ )
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
# check that number of inference steps is not < 1 - as this doesn't make sense
if num_inference_steps < 1:
@@ -1104,7 +1266,22 @@ def __call__(
f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}."
)
- # 7. Denoising loop
+ # 7. Prepare image embeddings
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+ else:
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+ # 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1125,6 +1302,7 @@ def __call__(
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
index 8f40fa72a25c..351b146fb423 100644
--- a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
@@ -30,6 +30,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -41,6 +42,14 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
logger = logging.get_logger(__name__)
EXAMPLE_DOC_STRING = """
@@ -194,8 +203,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -242,7 +251,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1008,6 +1017,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post-processing
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
index 2b86470dbff1..4b999662a6e7 100644
--- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
+++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
@@ -33,6 +33,7 @@
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -44,6 +45,13 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -268,8 +276,8 @@ class StableDiffusionDiffEditPipeline(
A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents.
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -292,7 +300,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -306,7 +314,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -336,17 +344,21 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -367,7 +379,7 @@ def __init__(
feature_extractor=feature_extractor,
inverse_scheduler=inverse_scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1508,6 +1520,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
index 52ccd5612776..4bbb93e44a83 100644
--- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
+++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
@@ -29,6 +29,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -40,8 +41,16 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -120,8 +129,8 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -168,7 +177,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -828,6 +837,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
index c6748ad418fe..86ef01784057 100644
--- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
+++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
@@ -32,7 +32,14 @@
from ...models.attention import GatedSelfAttentionDense
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -40,8 +47,16 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -172,8 +187,8 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -226,7 +241,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -446,13 +461,14 @@ def prepare_extra_step_kwargs(self, generator, eta):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
- # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
+ gligen_images,
+ gligen_phrases,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
@@ -499,6 +515,13 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)
+ if gligen_images is not None and gligen_phrases is not None:
+ if len(gligen_images) != len(gligen_phrases):
+ raise ValueError(
+ "`gligen_images` and `gligen_phrases` must have the same length when both are provided, but"
+ f" got: `gligen_images` with length {len(gligen_images)} != `gligen_phrases` with length {len(gligen_phrases)}."
+ )
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (
@@ -814,6 +837,8 @@ def __call__(
height,
width,
callback_steps,
+ gligen_images,
+ gligen_phrases,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
@@ -1000,6 +1025,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
index 122701ff923f..1f29f577f8e0 100755
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -19,15 +19,31 @@
import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPTokenizerFast,
+)
from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import (
+ StableDiffusionLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
-from ...schedulers import LMSDiscreteScheduler
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from ..stable_diffusion import StableDiffusionPipelineOutput
+from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -83,7 +99,8 @@ class StableDiffusionKDiffusionPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -94,13 +111,13 @@ class StableDiffusionKDiffusionPipeline(
def __init__(
self,
- vae,
- text_encoder,
- tokenizer,
- unet,
- scheduler,
- safety_checker,
- feature_extractor,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -124,7 +141,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
model = ModelWrapper(unet, scheduler.alphas_cumprod)
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
index 45f814fd538f..c7c5bd9cff67 100644
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
@@ -170,10 +170,14 @@ def __init__(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
model = ModelWrapper(unet, scheduler.alphas_cumprod)
if scheduler.config.prediction_type == "v_prediction":
@@ -321,7 +325,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -380,8 +386,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
index 81bb0e9a7270..702f3eda5816 100644
--- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
+++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
@@ -30,6 +30,7 @@
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -40,8 +41,16 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```python
@@ -203,8 +212,8 @@ class StableDiffusionLDM3DPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -254,7 +263,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1002,6 +1011,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
index 2fc79c0610f0..ccee6d47b47a 100644
--- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -26,6 +26,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -37,8 +38,16 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -179,8 +188,8 @@ class StableDiffusionPanoramaPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -230,7 +239,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1155,6 +1164,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type != "latent":
if circular_padding:
image = self.decode_latents_with_padding(latents)
diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
index cd59cf51869d..deae82eb8813 100644
--- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -12,13 +12,20 @@
from ...loaders import IPAdapterMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionSafePipelineOutput
from .safety_checker import SafeStableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -46,8 +53,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAda
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -74,7 +81,7 @@ def __init__(
" abuse, brutality, cruelty"
)
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -88,7 +95,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -117,17 +124,21 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
- version.parse(unet.config._diffusers_version).base_version
- ) < version.parse("0.9.0.dev0")
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ is_unet_version_less_0_9_0 = (
+ unet is not None
+ and hasattr(unet.config, "_diffusers_version")
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+ )
+ is_unet_sample_size_less_64 = (
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -149,7 +160,7 @@ def __init__(
image_encoder=image_encoder,
)
self._safety_text_concept = safety_concept
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
@property
@@ -739,6 +750,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post-processing
image = self.decode_latents(latents)
diff --git a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
index c32052d2e4d0..e96422073b19 100644
--- a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
@@ -27,6 +27,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -38,8 +39,16 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -47,7 +56,7 @@
>>> from diffusers import StableDiffusionSAGPipeline
>>> pipe = StableDiffusionSAGPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
@@ -123,8 +132,8 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -157,7 +166,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -840,6 +849,9 @@ def get_map_size(module, input, output):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
index 77363b2546d7..eb1030f3bb9d 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
@@ -65,7 +65,7 @@ def __init__(
unet=unet,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index a4757ac2f336..9c69fe65fbdb 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -237,11 +237,8 @@ class StableDiffusionXLPipeline(
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
- "negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
- "negative_pooled_prompt_embeds",
- "negative_add_time_ids",
]
def __init__(
@@ -272,10 +269,14 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -409,7 +410,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -468,8 +471,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1243,13 +1248,8 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
- negative_pooled_prompt_embeds = callback_outputs.pop(
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
- )
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
- negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 50688ddb1cb8..08d0b44d613d 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -257,11 +257,8 @@ class StableDiffusionXLImg2ImgPipeline(
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
- "negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
- "negative_pooled_prompt_embeds",
- "add_neg_time_ids",
]
def __init__(
@@ -294,7 +291,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -430,7 +427,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -489,8 +488,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1438,13 +1439,8 @@ def denoising_value_valid(dnv):
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
- negative_pooled_prompt_embeds = callback_outputs.pop(
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
- )
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
- add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index c7c706350e8e..920caf4d24a1 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -285,11 +285,8 @@ class StableDiffusionXLInpaintPipeline(
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
- "negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
- "negative_pooled_prompt_embeds",
- "add_neg_time_ids",
"mask",
"masked_image_latents",
]
@@ -324,7 +321,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -534,7 +531,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -593,8 +592,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1671,13 +1672,8 @@ def denoising_value_valid(dnv):
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
- negative_pooled_prompt_embeds = callback_outputs.pop(
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
- )
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
- add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
mask = callback_outputs.pop("mask", mask)
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index b59f2312726d..aaffe8efa730 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -199,9 +199,13 @@ def __init__(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
self.is_cosxl_edit = is_cosxl_edit
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -333,7 +337,9 @@ def encode_prompt(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
@@ -385,7 +391,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
index fb986075aeea..8c1af7863e63 100644
--- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
+++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
@@ -24,14 +24,22 @@
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...schedulers import EulerDiscreteScheduler
-from ...utils import BaseOutput, logging, replace_example_docstring
+from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -177,7 +185,7 @@ def __init__(
scheduler=scheduler,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)
def _encode_image(
@@ -600,6 +608,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# cast back to fp16 if needed
if needs_upcasting:
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 1a938aaf9423..6cd0e415e129 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -22,7 +22,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -31,6 +31,7 @@
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -41,6 +42,14 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
@dataclass
class StableDiffusionAdapterPipelineOutput(BaseOutput):
"""
@@ -59,6 +68,7 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput):
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -178,7 +188,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
+class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
https://arxiv.org/abs/2302.08453
@@ -208,7 +218,8 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -259,7 +270,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -914,6 +925,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type == "latent":
image = latents
has_nsfw_concept = None
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 20569d0adb32..5eacb64d01e3 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -43,6 +43,7 @@
from ...utils import (
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -53,8 +54,16 @@
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -248,7 +257,8 @@ class StableDiffusionXLAdapterPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -292,9 +302,13 @@ def __init__(
image_encoder=image_encoder,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
@@ -422,7 +436,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -481,8 +497,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1261,6 +1279,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
index cdd72b97f86b..5c63d66e3133 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -25,6 +25,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -36,8 +37,16 @@
from . import TextToVideoSDPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -105,7 +114,7 @@ def __init__(
unet=unet,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -627,6 +636,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
index 92bf1d388c13..006c7a79ce0d 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -26,6 +26,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -37,8 +38,16 @@
from . import TextToVideoSDPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -140,7 +149,7 @@ def __init__(
unet=unet,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -679,6 +688,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
index c95c7f1b9625..df85f470a80b 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -11,16 +11,30 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ BaseOutput,
+ is_torch_xla_available,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -282,7 +296,11 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
class TextToVideoZeroPipeline(
- DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ FromSingleFileMixin,
):
r"""
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
@@ -304,8 +322,8 @@ class TextToVideoZeroPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`CLIPImageProcessor`]):
A [`CLIPImageProcessor`] to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -340,7 +358,7 @@ def __init__(
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def forward_loop(self, x_t0, t0, t1, generator):
@@ -440,6 +458,10 @@ def backward_loop(
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
return latents.clone().detach()
# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
index 9ff473cc3a38..339d5b3a6019 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
@@ -42,6 +42,16 @@
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -409,10 +419,14 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- self.default_sample_size = self.unet.config.sample_size
+ self.default_sample_size = (
+ self.unet.config.sample_size
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+ else 128
+ )
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -705,7 +719,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -764,8 +780,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -922,6 +940,10 @@ def backward_loop(
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
return latents.clone().detach()
@torch.no_grad()
diff --git a/src/diffusers/pipelines/transformers_loading_utils.py b/src/diffusers/pipelines/transformers_loading_utils.py
new file mode 100644
index 000000000000..b52d154d6ba2
--- /dev/null
+++ b/src/diffusers/pipelines/transformers_loading_utils.py
@@ -0,0 +1,121 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import contextlib
+import os
+import tempfile
+from typing import TYPE_CHECKING, Dict
+
+from huggingface_hub import DDUFEntry
+from tqdm import tqdm
+
+from ..utils import is_safetensors_available, is_transformers_available, is_transformers_version
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+
+if is_transformers_available():
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+
+if is_safetensors_available():
+ import safetensors.torch
+
+
+def _load_tokenizer_from_dduf(
+ cls: "PreTrainedTokenizer", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
+) -> "PreTrainedTokenizer":
+ """
+ Load a tokenizer from a DDUF archive.
+
+ In practice, `transformers` do not provide a way to load a tokenizer from a DDUF archive. This function is a
+ workaround by extracting the tokenizer files from the DDUF archive and loading the tokenizer from the extracted
+ files. There is an extra cost of extracting the files, but of limited impact as the tokenizer files are usually
+ small-ish.
+ """
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ for entry_name, entry in dduf_entries.items():
+ if entry_name.startswith(name + "/"):
+ tmp_entry_path = os.path.join(tmp_dir, *entry_name.split("/"))
+ # need to create intermediary directory if they don't exist
+ os.makedirs(os.path.dirname(tmp_entry_path), exist_ok=True)
+ with open(tmp_entry_path, "wb") as f:
+ with entry.as_mmap() as mm:
+ f.write(mm)
+ return cls.from_pretrained(os.path.dirname(tmp_entry_path), **kwargs)
+
+
+def _load_transformers_model_from_dduf(
+ cls: "PreTrainedModel", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
+) -> "PreTrainedModel":
+ """
+ Load a transformers model from a DDUF archive.
+
+ In practice, `transformers` do not provide a way to load a model from a DDUF archive. This function is a workaround
+ by instantiating a model from the config file and loading the weights from the DDUF archive directly.
+ """
+ config_file = dduf_entries.get(f"{name}/config.json")
+ if config_file is None:
+ raise EnvironmentError(
+ f"Could not find a config.json file for component {name} in DDUF file (contains {dduf_entries.keys()})."
+ )
+ generation_config = dduf_entries.get(f"{name}/generation_config.json", None)
+
+ weight_files = [
+ entry
+ for entry_name, entry in dduf_entries.items()
+ if entry_name.startswith(f"{name}/") and entry_name.endswith(".safetensors")
+ ]
+ if not weight_files:
+ raise EnvironmentError(
+ f"Could not find any weight file for component {name} in DDUF file (contains {dduf_entries.keys()})."
+ )
+ if not is_safetensors_available():
+ raise EnvironmentError(
+ "Safetensors is not available, cannot load model from DDUF. Please `pip install safetensors`."
+ )
+ if is_transformers_version("<", "4.47.0"):
+ raise ImportError(
+ "You need to install `transformers>4.47.0` in order to load a transformers model from a DDUF file. "
+ "You can install it with: `pip install --upgrade transformers`"
+ )
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ from transformers import AutoConfig, GenerationConfig
+
+ tmp_config_file = os.path.join(tmp_dir, "config.json")
+ with open(tmp_config_file, "w") as f:
+ f.write(config_file.read_text())
+ config = AutoConfig.from_pretrained(tmp_config_file)
+ if generation_config is not None:
+ tmp_generation_config_file = os.path.join(tmp_dir, "generation_config.json")
+ with open(tmp_generation_config_file, "w") as f:
+ f.write(generation_config.read_text())
+ generation_config = GenerationConfig.from_pretrained(tmp_generation_config_file)
+ state_dict = {}
+ with contextlib.ExitStack() as stack:
+ for entry in tqdm(weight_files, desc="Loading state_dict"): # Loop over safetensors files
+ # Memory-map the safetensors file
+ mmap = stack.enter_context(entry.as_mmap())
+ # Load tensors from the memory-mapped file
+ tensors = safetensors.torch.load(mmap)
+ # Update the state dictionary with tensors
+ state_dict.update(tensors)
+ return cls.from_pretrained(
+ pretrained_model_name_or_path=None,
+ config=config,
+ generation_config=generation_config,
+ state_dict=state_dict,
+ **kwargs,
+ )
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py
index 25c6739d8720..bf42d44f74c1 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -22,12 +22,19 @@
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...schedulers import UnCLIPScheduler
-from ...utils import logging
+from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -474,6 +481,9 @@ def __call__(
noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = super_res_latents
# done super res
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
index 2a0e7e90e4d2..8fa0a848f7e7 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
@@ -27,12 +27,19 @@
from ...models import UNet2DConditionModel, UNet2DModel
from ...schedulers import UnCLIPScheduler
-from ...utils import logging
+from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -400,6 +407,9 @@ def __call__(
noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = super_res_latents
# done super res
diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
index cb1514b153ce..1e285a9670e2 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
@@ -104,8 +104,8 @@ def __init__(
self.use_pos_embed = use_pos_embed
if self.use_pos_embed:
- pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
- self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+ pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5), output_type="pt")
+ self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=False)
def forward(self, latent):
latent = self.proj(latent)
diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index 4f65caf4e610..66d7404fb9a5 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -18,7 +18,14 @@
from ...models import AutoencoderKL
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.outputs import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -26,6 +33,13 @@
from .modeling_uvit import UniDiffuserModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -117,7 +131,7 @@ def __init__(
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.num_channels_latents = vae.config.latent_channels
@@ -1378,6 +1392,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post-processing
image = None
text = None
diff --git a/src/diffusers/pipelines/wan/__init__.py b/src/diffusers/pipelines/wan/__init__.py
new file mode 100644
index 000000000000..80916a8a1e10
--- /dev/null
+++ b/src/diffusers/pipelines/wan/__init__.py
@@ -0,0 +1,51 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_wan"] = ["WanPipeline"]
+ _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
+ _import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_wan import WanPipeline
+ from .pipeline_wan_i2v import WanImageToVideoPipeline
+ from .pipeline_wan_video2video import WanVideoToVideoPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/wan/pipeline_output.py b/src/diffusers/pipelines/wan/pipeline_output.py
new file mode 100644
index 000000000000..88907ad0f0a1
--- /dev/null
+++ b/src/diffusers/pipelines/wan/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class WanPipelineOutput(BaseOutput):
+ r"""
+ Output class for Wan pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
new file mode 100644
index 000000000000..3294e9a56a07
--- /dev/null
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -0,0 +1,593 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import ftfy
+import regex as re
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import WanPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers.utils import export_to_video
+ >>> from diffusers import AutoencoderKLWan, WanPipeline
+ >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
+
+ >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+ >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=720,
+ ... width=1280,
+ ... num_frames=81,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using Wan.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`WanTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: WanTransformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `480`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `832`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
+ The dtype to use for the torch.amp.autocast.
+
+ Examples:
+
+ Returns:
+ [`~WanPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
new file mode 100644
index 000000000000..fd1d90849a66
--- /dev/null
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -0,0 +1,703 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import ftfy
+import PIL
+import regex as re
+import torch
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import WanPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> import numpy as np
+ >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+ >>> from transformers import CLIPVisionModel
+
+ >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
+ >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+ >>> image_encoder = CLIPVisionModel.from_pretrained(
+ ... model_id, subfolder="image_encoder", torch_dtype=torch.float32
+ ... )
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = WanImageToVideoPipeline.from_pretrained(
+ ... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+ ... )
+ >>> max_area = 480 * 832
+ >>> aspect_ratio = image.height / image.width
+ >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ >>> image = image.resize((width, height))
+ >>> prompt = (
+ ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
+ ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+ ... )
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=height,
+ ... width=width,
+ ... num_frames=81,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation using Wan.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ image_encoder ([`CLIPVisionModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
+ the
+ [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
+ variant.
+ transformer ([`WanTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ image_encoder: CLIPVisionModel,
+ image_processor: CLIPImageProcessor,
+ transformer: WanTransformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.image_processor = image_processor
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_image(
+ self,
+ image: PipelineImageInput,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+ image = self.image_processor(images=image, return_tensors="pt").to(device)
+ image_embeds = self.image_encoder(**image, output_hidden_states=True)
+ return image_embeds.hidden_states[-2]
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ image = image.unsqueeze(2)
+ video_condition = torch.cat(
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+ )
+ video_condition = video_condition.to(device=device, dtype=dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ if isinstance(generator, list):
+ latent_condition = [
+ retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+ ]
+ latent_condition = torch.cat(latent_condition)
+ else:
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
+ latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+ latent_condition = (latent_condition - latents_mean) * latents_std
+
+ mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+ mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
+ mask_lat_size = mask_lat_size.transpose(1, 2)
+ mask_lat_size = mask_lat_size.to(latent_condition.device)
+
+ return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `480`):
+ The height of the generated video.
+ width (`int`, defaults to `832`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length of the prompt.
+ shift (`float`, *optional*, defaults to `5.0`):
+ The shift of the flow.
+ autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
+ The dtype to use for the torch.amp.autocast.
+ Examples:
+
+ Returns:
+ [`~WanPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Encode image embedding
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ image_embeds = self.encode_image(image, device)
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
+ image_embeds = image_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.z_dim
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+ latents, condition = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
new file mode 100644
index 000000000000..c72dd7f5f1eb
--- /dev/null
+++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
@@ -0,0 +1,725 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import ftfy
+import regex as re
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import WanPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers.utils import export_to_video
+ >>> from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline
+ >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
+
+ >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+ >>> model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = WanVideoToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A robot standing on a mountain top. The sun is setting in the background"
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+ >>> video = load_video(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
+ ... )
+ >>> output = pipe(
+ ... video=video,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=720,
+ ... guidance_scale=5.0,
+ ... strength=0.7,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for video-to-video generation using Wan.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`WanTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: WanTransformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ video=None,
+ latents=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if video is not None and latents is not None:
+ raise ValueError("Only one of `video` or `latents` should be provided")
+
+ def prepare_latents(
+ self,
+ video: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.Tensor] = None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ num_latent_frames = (
+ (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
+ )
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ if latents is None:
+ if isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+ ]
+ else:
+ init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
+
+ init_latents = torch.cat(init_latents, dim=0).to(dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device, dtype
+ )
+
+ init_latents = (init_latents - latents_mean) * latents_std
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ if hasattr(self.scheduler, "add_noise"):
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ else:
+ latents = self.scheduelr.scale_noise(init_latents, timestep, noise)
+ else:
+ latents = latents.to(device)
+
+ return latents
+
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ video: List[Image.Image] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 5.0,
+ strength: float = 0.8,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `480`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `832`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
+ The dtype to use for the torch.amp.autocast.
+
+ Examples:
+
+ Returns:
+ [`~WanPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ video,
+ latents,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+ self._num_timesteps = len(timesteps)
+
+ if latents is None:
+ video = self.video_processor.preprocess_video(video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ video,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ latent_timestep,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
index edb0c1ec45de..9863c506d743 100644
--- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
@@ -29,7 +29,6 @@
AttnProcessor,
)
from ...models.modeling_utils import ModelMixin
-from ...utils import is_torch_version
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
@@ -138,9 +137,6 @@ def set_default_attn_processor(self):
self.set_attn_processor(processor)
- def _set_gradient_checkpointing(self, module, value=False):
- self.gradient_checkpointing = value
-
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
@@ -158,34 +154,14 @@ def forward(self, x, r, c):
c_embed = self.cond_mapper(c)
r_embed = self.gen_r_embedding(r)
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- for block in self.blocks:
- if isinstance(block, AttnBlock):
- x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block), x, c_embed, use_reentrant=False
- )
- elif isinstance(block, TimestepBlock):
- x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block), x, r_embed, use_reentrant=False
- )
- else:
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
- else:
- for block in self.blocks:
- if isinstance(block, AttnBlock):
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed)
- elif isinstance(block, TimestepBlock):
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed)
- else:
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ if isinstance(block, AttnBlock):
+ x = self._gradient_checkpointing_func(block, x, c_embed)
+ elif isinstance(block, TimestepBlock):
+ x = self._gradient_checkpointing_func(block, x, r_embed)
+ else:
+ x = self._gradient_checkpointing_func(block, x)
else:
for block in self.blocks:
if isinstance(block, AttnBlock):
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
index b08421415b23..edc01f0d5c75 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
@@ -19,15 +19,23 @@
from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -413,6 +421,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
index 92223ce993a6..8f6ba419721d 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
@@ -22,14 +22,22 @@
from ...loaders import StableDiffusionLoraLoaderMixin
from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import BaseOutput, deprecate, logging, replace_example_docstring
+from ...utils import BaseOutput, deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .modeling_wuerstchen_prior import WuerstchenPrior
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
EXAMPLE_DOC_STRING = """
@@ -502,6 +510,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 10. Denormalize the latents
latents = latents * self.config.latent_mean - self.config.latent_std
diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py
index 93852d29ef59..4c8483a3d6ee 100644
--- a/src/diffusers/quantizers/__init__.py
+++ b/src/diffusers/quantizers/__init__.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .auto import DiffusersAutoQuantizationConfig, DiffusersAutoQuantizer
+from .auto import DiffusersAutoQuantizer
from .base import DiffusersQuantizer
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py
index f231f279e13a..ce214ae7bc17 100644
--- a/src/diffusers/quantizers/auto.py
+++ b/src/diffusers/quantizers/auto.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,28 +15,45 @@
Adapted from
https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py
"""
+
import warnings
from typing import Dict, Optional, Union
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
-from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
+from .gguf import GGUFQuantizer
+from .quantization_config import (
+ BitsAndBytesConfig,
+ GGUFQuantizationConfig,
+ QuantizationConfigMixin,
+ QuantizationMethod,
+ QuantoConfig,
+ TorchAoConfig,
+)
+from .quanto import QuantoQuantizer
+from .torchao import TorchAoHfQuantizer
AUTO_QUANTIZER_MAPPING = {
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
+ "gguf": GGUFQuantizer,
+ "quanto": QuantoQuantizer,
+ "torchao": TorchAoHfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
"bitsandbytes_4bit": BitsAndBytesConfig,
"bitsandbytes_8bit": BitsAndBytesConfig,
+ "gguf": GGUFQuantizationConfig,
+ "quanto": QuantoConfig,
+ "torchao": TorchAoConfig,
}
-class DiffusersAutoQuantizationConfig:
+class DiffusersAutoQuantizer:
"""
- The auto diffusers quantization config class that takes care of automatically dispatching to the correct
- quantization config given a quantization config stored in a dictionary.
+ The auto diffusers quantizer class that takes care of automatically instantiating to the correct
+ `DiffusersQuantizer` given the `QuantizationConfig`.
"""
@classmethod
@@ -60,31 +77,11 @@ def from_dict(cls, quantization_config_dict: Dict):
target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
return target_cls.from_dict(quantization_config_dict)
- @classmethod
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
- model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
- if getattr(model_config, "quantization_config", None) is None:
- raise ValueError(
- f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
- )
- quantization_config_dict = model_config.quantization_config
- quantization_config = cls.from_dict(quantization_config_dict)
- # Update with potential kwargs that are passed through from_pretrained.
- quantization_config.update(kwargs)
- return quantization_config
-
-
-class DiffusersAutoQuantizer:
- """
- The auto diffusers quantizer class that takes care of automatically instantiating to the correct
- `DiffusersQuantizer` given the `QuantizationConfig`.
- """
-
@classmethod
def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
# Convert it to a QuantizationConfig if the q_config is a dict
if isinstance(quantization_config, dict):
- quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
+ quantization_config = cls.from_dict(quantization_config)
quant_method = quantization_config.quant_method
@@ -107,7 +104,16 @@ def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict],
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
- quantization_config = DiffusersAutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
+ if getattr(model_config, "quantization_config", None) is None:
+ raise ValueError(
+ f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
+ )
+ quantization_config_dict = model_config.quantization_config
+ quantization_config = cls.from_dict(quantization_config_dict)
+ # Update with potential kwargs that are passed through from_pretrained.
+ quantization_config.update(kwargs)
+
return cls.from_config(quantization_config)
@classmethod
@@ -129,7 +135,7 @@ def merge_quantization_configs(
warning_msg = ""
if isinstance(quantization_config, dict):
- quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
+ quantization_config = cls.from_dict(quantization_config)
if warning_msg != "":
warnings.warn(warning_msg)
diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py
index 017136a98854..1c75b5bef933 100644
--- a/src/diffusers/quantizers/base.py
+++ b/src/diffusers/quantizers/base.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -134,7 +134,7 @@ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str,
"""adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
return max_memory
- def check_quantized_param(
+ def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
@@ -152,10 +152,13 @@ def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
"""
takes needed components from state_dict and creates quantized param.
"""
- if not hasattr(self, "check_quantized_param"):
- raise AttributeError(
- f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
- )
+ return
+
+ def check_quantized_param_shape(self, *args, **kwargs):
+ """
+ checks if the quantized param has expected shape.
+ """
+ return True
def validate_environment(self, *args, **kwargs):
"""
diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
index e3041aba60ae..689d8e4256c2 100644
--- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
+++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -61,7 +61,7 @@ def __init__(self, quantization_config, **kwargs):
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
def validate_environment(self, *args, **kwargs):
- if not torch.cuda.is_available():
+ if not (torch.cuda.is_available() or torch.xpu.is_available()):
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
raise ImportError(
@@ -106,7 +106,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
else:
raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")
- def check_quantized_param(
+ def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
@@ -135,6 +135,7 @@ def create_quantized_param(
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
+ **kwargs,
):
import bitsandbytes as bnb
@@ -204,6 +205,19 @@ def create_quantized_param(
module._parameters[tensor_name] = new_value
+ def check_quantized_param_shape(self, param_name, current_param, loaded_param):
+ current_param_shape = current_param.shape
+ loaded_param_shape = loaded_param.shape
+
+ n = current_param_shape.numel()
+ inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
+ if loaded_param_shape != inferred_shape:
+ raise ValueError(
+ f"Expected the flattened shape of the current param ({param_name}) to be {loaded_param_shape} but is {inferred_shape}."
+ )
+ else:
+ return True
+
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
# need more space for buffers that are created during quantization
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
@@ -222,18 +236,20 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
torch_dtype = torch.float16
return torch_dtype
- # (sayakpaul): I think it could be better to disable custom `device_map`s
- # for the first phase of the integration in the interest of simplicity.
- # Commenting this for discussions on the PR.
- # def update_device_map(self, device_map):
- # if device_map is None:
- # device_map = {"": torch.cuda.current_device()}
- # logger.info(
- # "The device_map was not initialized. "
- # "Setting device_map to {'':torch.cuda.current_device()}. "
- # "If you want to use the model for inference, please set device_map ='auto' "
- # )
- # return device_map
+ def update_device_map(self, device_map):
+ if device_map is None:
+ if torch.xpu.is_available():
+ current_device = f"xpu:{torch.xpu.current_device()}"
+ else:
+ current_device = f"cuda:{torch.cuda.current_device()}"
+ device_map = {"": current_device}
+ logger.info(
+ "The device_map was not initialized. "
+ "Setting device_map to {"
+ ": {current_device}}. "
+ "If you want to use the model for inference, please set device_map ='auto' "
+ )
+ return device_map
def _process_model_before_weight_loading(
self,
@@ -276,9 +292,9 @@ def _process_model_before_weight_loading(
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
model.config.quantization_config = self.quantization_config
+ model.is_loaded_in_4bit = True
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
- model.is_loaded_in_4bit = True
model.is_4bit_serializable = self.is_serializable
return model
@@ -300,7 +316,10 @@ def _dequantize(self, model):
logger.info(
"Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
)
- model.to(torch.cuda.current_device())
+ if torch.xpu.is_available():
+ model.to(torch.xpu.current_device())
+ else:
+ model.to(torch.cuda.current_device())
model = dequantize_and_replace(
model, self.modules_to_not_convert, quantization_config=self.quantization_config
@@ -330,9 +349,8 @@ def __init__(self, quantization_config, **kwargs):
if self.quantization_config.llm_int8_skip_modules is not None:
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
- # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4-bit->8-bit
def validate_environment(self, *args, **kwargs):
- if not torch.cuda.is_available():
+ if not (torch.cuda.is_available() or torch.xpu.is_available()):
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
raise ImportError(
@@ -388,23 +406,28 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
torch_dtype = torch.float16
return torch_dtype
- # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
- # def update_device_map(self, device_map):
- # if device_map is None:
- # device_map = {"": torch.cuda.current_device()}
- # logger.info(
- # "The device_map was not initialized. "
- # "Setting device_map to {'':torch.cuda.current_device()}. "
- # "If you want to use the model for inference, please set device_map ='auto' "
- # )
- # return device_map
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
+ def update_device_map(self, device_map):
+ if device_map is None:
+ if torch.xpu.is_available():
+ current_device = f"xpu:{torch.xpu.current_device()}"
+ else:
+ current_device = f"cuda:{torch.cuda.current_device()}"
+ device_map = {"": current_device}
+ logger.info(
+ "The device_map was not initialized. "
+ "Setting device_map to {"
+ ": {current_device}}. "
+ "If you want to use the model for inference, please set device_map ='auto' "
+ )
+ return device_map
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
if target_dtype != torch.int8:
logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
return torch.int8
- def check_quantized_param(
+ def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
@@ -434,6 +457,7 @@ def create_quantized_param(
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
+ **kwargs,
):
import bitsandbytes as bnb
@@ -481,11 +505,10 @@ def create_quantized_param(
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
- model.is_loaded_in_8bit = True
model.is_8bit_serializable = self.is_serializable
return model
- # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading with 4bit->8bit
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
@@ -527,6 +550,7 @@ def _process_model_before_weight_loading(
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
model.config.quantization_config = self.quantization_config
+ model.is_loaded_in_8bit = True
@property
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable
diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py
index 03755db3d1ec..a9771b368a86 100644
--- a/src/diffusers/quantizers/bitsandbytes/utils.py
+++ b/src/diffusers/quantizers/bitsandbytes/utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -153,8 +153,8 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
return model
-# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
-def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+# Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
"""
Helper function to dequantize 4bit or 8bit bnb weights.
@@ -177,13 +177,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
if state.SCB is None:
state.SCB = weight.SCB
- im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
- im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
- im, Sim = bnb.functional.transform(im, "col32")
- if state.CxB is None:
- state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
- out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
- return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+ if hasattr(bnb.functional, "int8_vectorwise_dequant"):
+ # Use bitsandbytes API if available (requires v0.45.0+)
+ dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
+ else:
+ # Multiply by (scale/127) to dequantize.
+ dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
+
+ if dtype:
+ dequantized = dequantized.to(dtype)
+ return dequantized
def _create_accelerate_new_hook(old_hook):
@@ -205,6 +208,7 @@ def _create_accelerate_new_hook(old_hook):
def _dequantize_and_replace(
model,
+ dtype,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
@@ -244,7 +248,7 @@ def _dequantize_and_replace(
else:
state = None
- new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+ new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))
if bias is not None:
new_module.bias = bias
@@ -263,9 +267,10 @@ def _dequantize_and_replace(
if len(list(module.children())) > 0:
_, has_been_replaced = _dequantize_and_replace(
module,
- modules_to_not_convert,
- current_key_name,
- quantization_config,
+ dtype=dtype,
+ modules_to_not_convert=modules_to_not_convert,
+ current_key_name=current_key_name,
+ quantization_config=quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
@@ -280,6 +285,7 @@ def dequantize_and_replace(
):
model, has_been_replaced = _dequantize_and_replace(
model,
+ dtype=model.dtype,
modules_to_not_convert=modules_to_not_convert,
quantization_config=quantization_config,
)
diff --git a/src/diffusers/quantizers/gguf/__init__.py b/src/diffusers/quantizers/gguf/__init__.py
new file mode 100644
index 000000000000..b3d9082ac803
--- /dev/null
+++ b/src/diffusers/quantizers/gguf/__init__.py
@@ -0,0 +1 @@
+from .gguf_quantizer import GGUFQuantizer
diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py
new file mode 100644
index 000000000000..6da69c7bd60c
--- /dev/null
+++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -0,0 +1,160 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+ from ...models.modeling_utils import ModelMixin
+
+
+from ...utils import (
+ get_module_from_name,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_gguf_available,
+ is_gguf_version,
+ is_torch_available,
+ logging,
+)
+
+
+if is_torch_available() and is_gguf_available():
+ import torch
+
+ from .utils import (
+ GGML_QUANT_SIZES,
+ GGUFParameter,
+ _dequantize_gguf_and_restore_linear,
+ _quant_shape_from_byte_shape,
+ _replace_with_gguf_linear,
+ )
+
+
+logger = logging.get_logger(__name__)
+
+
+class GGUFQuantizer(DiffusersQuantizer):
+ use_keep_in_fp32_modules = True
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ self.compute_dtype = quantization_config.compute_dtype
+ self.pre_quantized = quantization_config.pre_quantized
+ self.modules_to_not_convert = quantization_config.modules_to_not_convert
+
+ if not isinstance(self.modules_to_not_convert, list):
+ self.modules_to_not_convert = [self.modules_to_not_convert]
+
+ def validate_environment(self, *args, **kwargs):
+ if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
+ raise ImportError(
+ "Loading GGUF Parameters requires `accelerate` installed in your enviroment: `pip install 'accelerate>=0.26.0'`"
+ )
+ if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
+ raise ImportError(
+ "To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`"
+ )
+
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ # need more space for buffers that are created during quantization
+ max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+ return max_memory
+
+ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+ if target_dtype != torch.uint8:
+ logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization")
+ return torch.uint8
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ if torch_dtype is None:
+ torch_dtype = self.compute_dtype
+ return torch_dtype
+
+ def check_quantized_param_shape(self, param_name, current_param, loaded_param):
+ loaded_param_shape = loaded_param.shape
+ current_param_shape = current_param.shape
+ quant_type = loaded_param.quant_type
+
+ block_size, type_size = GGML_QUANT_SIZES[quant_type]
+
+ inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
+ if inferred_shape != current_param_shape:
+ raise ValueError(
+ f"{param_name} has an expected quantized shape of: {inferred_shape}, but receieved shape: {loaded_param_shape}"
+ )
+
+ return True
+
+ def check_if_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: Union["GGUFParameter", "torch.Tensor"],
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ) -> bool:
+ if isinstance(param_value, GGUFParameter):
+ return True
+
+ return False
+
+ def create_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: Union["GGUFParameter", "torch.Tensor"],
+ param_name: str,
+ target_device: "torch.device",
+ state_dict: Optional[Dict[str, Any]] = None,
+ unexpected_keys: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ module, tensor_name = get_module_from_name(model, param_name)
+ if tensor_name not in module._parameters and tensor_name not in module._buffers:
+ raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+
+ if tensor_name in module._parameters:
+ module._parameters[tensor_name] = param_value.to(target_device)
+ if tensor_name in module._buffers:
+ module._buffers[tensor_name] = param_value.to(target_device)
+
+ def _process_model_before_weight_loading(
+ self,
+ model: "ModelMixin",
+ device_map,
+ keep_in_fp32_modules: List[str] = [],
+ **kwargs,
+ ):
+ state_dict = kwargs.get("state_dict", None)
+
+ self.modules_to_not_convert.extend(keep_in_fp32_modules)
+ self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
+
+ _replace_with_gguf_linear(
+ model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert
+ )
+
+ def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
+ return model
+
+ @property
+ def is_serializable(self):
+ return False
+
+ @property
+ def is_trainable(self) -> bool:
+ return False
+
+ def _dequantize(self, model):
+ is_model_on_cpu = model.device.type == "cpu"
+ if is_model_on_cpu:
+ logger.info(
+ "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
+ )
+ model.to(torch.cuda.current_device())
+
+ model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
+ if is_model_on_cpu:
+ model.to("cpu")
+ return model
diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
new file mode 100644
index 000000000000..effc39d8fe97
--- /dev/null
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -0,0 +1,456 @@
+# Copyright 2024 The HuggingFace Team and City96. All rights reserved.
+# #
+# # Licensed under the Apache License, Version 2.0 (the "License");
+# # you may not use this file except in compliance with the License.
+# # You may obtain a copy of the License at
+# #
+# # http://www.apache.org/licenses/LICENSE-2.0
+# #
+# # Unless required by applicable law or agreed to in writing, software
+# # distributed under the License is distributed on an "AS IS" BASIS,
+# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# # See the License for the specific language governing permissions and
+# # limitations under the License.
+
+
+import inspect
+from contextlib import nullcontext
+
+import gguf
+import torch
+import torch.nn as nn
+
+from ...utils import is_accelerate_available
+
+
+if is_accelerate_available():
+ import accelerate
+ from accelerate import init_empty_weights
+ from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+
+
+# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
+def _create_accelerate_new_hook(old_hook):
+ r"""
+ Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of:
+ https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
+ some changes
+ """
+ old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
+ old_hook_attr = old_hook.__dict__
+ filtered_old_hook_attr = {}
+ old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
+ for k in old_hook_attr.keys():
+ if k in old_hook_init_signature.parameters:
+ filtered_old_hook_attr[k] = old_hook_attr[k]
+ new_hook = old_hook_cls(**filtered_old_hook_attr)
+ return new_hook
+
+
+def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]):
+ def _should_convert_to_gguf(state_dict, prefix):
+ weight_key = prefix + "weight"
+ return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter)
+
+ has_children = list(model.children())
+ if not has_children:
+ return
+
+ for name, module in model.named_children():
+ module_prefix = prefix + name + "."
+ _replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert)
+
+ if (
+ isinstance(module, nn.Linear)
+ and _should_convert_to_gguf(state_dict, module_prefix)
+ and name not in modules_to_not_convert
+ ):
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ model._modules[name] = GGUFLinear(
+ module.in_features,
+ module.out_features,
+ module.bias is not None,
+ compute_dtype=compute_dtype,
+ )
+ model._modules[name].source_cls = type(module)
+ # Force requires_grad to False to avoid unexpected errors
+ model._modules[name].requires_grad_(False)
+
+ return model
+
+
+def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=[]):
+ for name, module in model.named_children():
+ if isinstance(module, GGUFLinear) and name not in modules_to_not_convert:
+ device = module.weight.device
+ bias = getattr(module, "bias", None)
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ new_module = nn.Linear(
+ module.in_features,
+ module.out_features,
+ module.bias is not None,
+ device=device,
+ )
+ new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight))
+ if bias is not None:
+ new_module.bias = bias
+
+ # Create a new hook and attach it in case we use accelerate
+ if hasattr(module, "_hf_hook"):
+ old_hook = module._hf_hook
+ new_hook = _create_accelerate_new_hook(old_hook)
+
+ remove_hook_from_module(module)
+ add_hook_to_module(new_module, new_hook)
+
+ new_module.to(device)
+ model._modules[name] = new_module
+
+ has_children = list(module.children())
+ if has_children:
+ _dequantize_gguf_and_restore_linear(module, modules_to_not_convert)
+
+ return model
+
+
+# dequantize operations based on torch ports of GGUF dequantize_functions
+# from City96
+# more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py
+
+
+QK_K = 256
+K_SCALE_SIZE = 12
+
+
+def to_uint32(x):
+ x = x.view(torch.uint8).to(torch.int32)
+ return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
+
+
+def split_block_dims(blocks, *args):
+ n_max = blocks.shape[1]
+ dims = list(args) + [n_max - sum(args)]
+ return torch.split(blocks, dims, dim=1)
+
+
+def get_scale_min(scales):
+ n_blocks = scales.shape[0]
+ scales = scales.view(torch.uint8)
+ scales = scales.reshape((n_blocks, 3, 4))
+
+ d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
+
+ sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
+ min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
+
+ return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
+
+
+def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
+ d, x = split_block_dims(blocks, 2)
+ d = d.view(torch.float16).to(dtype)
+ x = x.view(torch.int8)
+ return d * x
+
+
+def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+
+ d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
+ d = d.view(torch.float16).to(dtype)
+ m = m.view(torch.float16).to(dtype)
+ qh = to_uint32(qh)
+
+ qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
+ ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+ [0, 4], device=d.device, dtype=torch.uint8
+ ).reshape(1, 1, 2, 1)
+ qh = (qh & 1).to(torch.uint8)
+ ql = (ql & 0x0F).reshape((n_blocks, -1))
+
+ qs = ql | (qh << 4)
+ return (d * qs) + m
+
+
+def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+
+ d, qh, qs = split_block_dims(blocks, 2, 4)
+ d = d.view(torch.float16).to(dtype)
+ qh = to_uint32(qh)
+
+ qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
+ ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
+ [0, 4], device=d.device, dtype=torch.uint8
+ ).reshape(1, 1, 2, 1)
+
+ qh = (qh & 1).to(torch.uint8)
+ ql = (ql & 0x0F).reshape(n_blocks, -1)
+
+ qs = (ql | (qh << 4)).to(torch.int8) - 16
+ return d * qs
+
+
+def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+
+ d, m, qs = split_block_dims(blocks, 2, 2)
+ d = d.view(torch.float16).to(dtype)
+ m = m.view(torch.float16).to(dtype)
+
+ qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+ [0, 4], device=d.device, dtype=torch.uint8
+ ).reshape(1, 1, 2, 1)
+ qs = (qs & 0x0F).reshape(n_blocks, -1)
+
+ return (d * qs) + m
+
+
+def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+
+ d, qs = split_block_dims(blocks, 2)
+ d = d.view(torch.float16).to(dtype)
+
+ qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+ [0, 4], device=d.device, dtype=torch.uint8
+ ).reshape((1, 1, 2, 1))
+ qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
+ return d * qs
+
+
+def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+
+ (
+ ql,
+ qh,
+ scales,
+ d,
+ ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)
+
+ scales = scales.view(torch.int8).to(dtype)
+ d = d.view(torch.float16).to(dtype)
+ d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
+
+ ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
+ (1, 1, 2, 1)
+ )
+ ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
+ qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
+ (1, 1, 4, 1)
+ )
+ qh = (qh & 0x03).reshape((n_blocks, -1, 32))
+ q = (ql | (qh << 4)).to(torch.int8) - 32
+ q = q.reshape((n_blocks, QK_K // 16, -1))
+
+ return (d * q).reshape((n_blocks, QK_K))
+
+
+def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+
+ d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)
+
+ d = d.view(torch.float16).to(dtype)
+ dmin = dmin.view(torch.float16).to(dtype)
+
+ sc, m = get_scale_min(scales)
+
+ d = (d * sc).reshape((n_blocks, -1, 1))
+ dm = (dmin * m).reshape((n_blocks, -1, 1))
+
+ ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
+ (1, 1, 2, 1)
+ )
+ qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
+ (1, 1, 8, 1)
+ )
+ ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
+ qh = (qh & 0x01).reshape((n_blocks, -1, 32))
+ q = ql | (qh << 4)
+
+ return (d * q - dm).reshape((n_blocks, QK_K))
+
+
+def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+
+ d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
+ d = d.view(torch.float16).to(dtype)
+ dmin = dmin.view(torch.float16).to(dtype)
+
+ sc, m = get_scale_min(scales)
+
+ d = (d * sc).reshape((n_blocks, -1, 1))
+ dm = (dmin * m).reshape((n_blocks, -1, 1))
+
+ qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
+ (1, 1, 2, 1)
+ )
+ qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
+
+ return (d * qs - dm).reshape((n_blocks, QK_K))
+
+
+def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+
+ hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
+ d = d.view(torch.float16).to(dtype)
+
+ lscales, hscales = scales[:, :8], scales[:, 8:]
+ lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
+ (1, 2, 1)
+ )
+ lscales = lscales.reshape((n_blocks, 16))
+ hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
+ [0, 2, 4, 6], device=d.device, dtype=torch.uint8
+ ).reshape((1, 4, 1))
+ hscales = hscales.reshape((n_blocks, 16))
+ scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
+ scales = scales.to(torch.int8) - 32
+
+ dl = (d * scales).reshape((n_blocks, 16, 1))
+
+ ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
+ (1, 1, 4, 1)
+ )
+ qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
+ (1, 1, 8, 1)
+ )
+ ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
+ qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
+ q = ql.to(torch.int8) - (qh << 2).to(torch.int8)
+
+ return (dl * q).reshape((n_blocks, QK_K))
+
+
+def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
+ n_blocks = blocks.shape[0]
+
+ scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
+ d = d.view(torch.float16).to(dtype)
+ dmin = dmin.view(torch.float16).to(dtype)
+
+ # (n_blocks, 16, 1)
+ dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
+ ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
+
+ shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
+
+ qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
+ qs = qs.reshape((n_blocks, QK_K // 16, 16))
+ qs = dl * qs - ml
+
+ return qs.reshape((n_blocks, -1))
+
+
+def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
+ return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
+
+
+GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
+dequantize_functions = {
+ gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
+ gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
+ gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
+ gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
+ gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
+ gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
+ gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
+ gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
+ gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
+ gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
+ gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
+}
+SUPPORTED_GGUF_QUANT_TYPES = list(dequantize_functions.keys())
+
+
+def _quant_shape_from_byte_shape(shape, type_size, block_size):
+ return (*shape[:-1], shape[-1] // type_size * block_size)
+
+
+def dequantize_gguf_tensor(tensor):
+ if not hasattr(tensor, "quant_type"):
+ return tensor
+
+ quant_type = tensor.quant_type
+ dequant_fn = dequantize_functions[quant_type]
+
+ block_size, type_size = GGML_QUANT_SIZES[quant_type]
+
+ tensor = tensor.view(torch.uint8)
+ shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size)
+
+ n_blocks = tensor.numel() // type_size
+ blocks = tensor.reshape((n_blocks, type_size))
+
+ dequant = dequant_fn(blocks, block_size, type_size)
+ dequant = dequant.reshape(shape)
+
+ return dequant.as_tensor()
+
+
+class GGUFParameter(torch.nn.Parameter):
+ def __new__(cls, data, requires_grad=False, quant_type=None):
+ data = data if data is not None else torch.empty(0)
+ self = torch.Tensor._make_subclass(cls, data, requires_grad)
+ self.quant_type = quant_type
+
+ return self
+
+ def as_tensor(self):
+ return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+
+ result = super().__torch_function__(func, types, args, kwargs)
+
+ # When converting from original format checkpoints we often use splits, cats etc on tensors
+ # this method ensures that the returned tensor type from those operations remains GGUFParameter
+ # so that we preserve quant_type information
+ quant_type = None
+ for arg in args:
+ if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
+ quant_type = arg[0].quant_type
+ break
+ if isinstance(arg, GGUFParameter):
+ quant_type = arg.quant_type
+ break
+ if isinstance(result, torch.Tensor):
+ return cls(result, quant_type=quant_type)
+ # Handle tuples and lists
+ elif isinstance(result, (tuple, list)):
+ # Preserve the original type (tuple or list)
+ wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
+ return type(result)(wrapped)
+ else:
+ return result
+
+
+class GGUFLinear(nn.Linear):
+ def __init__(
+ self,
+ in_features,
+ out_features,
+ bias=False,
+ compute_dtype=None,
+ device=None,
+ ) -> None:
+ super().__init__(in_features, out_features, bias, device)
+ self.compute_dtype = compute_dtype
+
+ def forward(self, inputs):
+ weight = dequantize_gguf_tensor(self.weight)
+ weight = weight.to(self.compute_dtype)
+ bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
+
+ output = torch.nn.functional.linear(inputs, weight, bias)
+ return output
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index f521c5d717d6..0bc433be0ff3 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -22,15 +22,17 @@
import copy
import importlib.metadata
+import inspect
import json
import os
from dataclasses import dataclass
from enum import Enum
-from typing import Any, Dict, Union
+from functools import partial
+from typing import Any, Dict, List, Optional, Union
from packaging import version
-from ..utils import is_torch_available, logging
+from ..utils import is_torch_available, is_torchao_available, logging
if is_torch_available():
@@ -41,6 +43,19 @@
class QuantizationMethod(str, Enum):
BITS_AND_BYTES = "bitsandbytes"
+ GGUF = "gguf"
+ TORCHAO = "torchao"
+ QUANTO = "quanto"
+
+
+if is_torchao_available():
+ from torchao.quantization.quant_primitives import MappingType
+
+ class TorchAoJSONEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, MappingType):
+ return obj.name
+ return super().default(obj)
@dataclass
@@ -389,3 +404,321 @@ def to_diff_dict(self) -> Dict[str, Any]:
serializable_config_dict[key] = value
return serializable_config_dict
+
+
+@dataclass
+class GGUFQuantizationConfig(QuantizationConfigMixin):
+ """This is a config class for GGUF Quantization techniques.
+
+ Args:
+ compute_dtype: (`torch.dtype`, defaults to `torch.float32`):
+ This sets the computational type which might be different than the input type. For example, inputs might be
+ fp32, but computation can be set to bf16 for speedups.
+
+ """
+
+ def __init__(self, compute_dtype: Optional["torch.dtype"] = None):
+ self.quant_method = QuantizationMethod.GGUF
+ self.compute_dtype = compute_dtype
+ self.pre_quantized = True
+
+ # TODO: (Dhruv) Add this as an init argument when we can support loading unquantized checkpoints.
+ self.modules_to_not_convert = None
+
+ if self.compute_dtype is None:
+ self.compute_dtype = torch.float32
+
+
+@dataclass
+class TorchAoConfig(QuantizationConfigMixin):
+ """This is a config class for torchao quantization/sparsity techniques.
+
+ Args:
+ quant_type (`str`):
+ The type of quantization we want to use, currently supporting:
+ - **Integer quantization:**
+ - Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
+ `int8_weight_only`, `int8_dynamic_activation_int8_weight`
+ - Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq`
+
+ - **Floating point 8-bit quantization:**
+ - Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`,
+ `float8_static_activation_float8_weight`
+ - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
+ `float8_e4m3_tensor`, `float8_e4m3_row`,
+
+ - **Floating point X-bit quantization:**
+ - Full function names: `fpx_weight_only`
+ - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
+ of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
+ be satisfied for a given shorthand notation.
+
+ - **Unsigned Integer quantization:**
+ - Full function names: `uintx_weight_only`
+ - Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
+ modules_to_not_convert (`List[str]`, *optional*, default to `None`):
+ The list of modules to not quantize, useful for quantizing models that explicitly require to have some
+ modules left in their original precision.
+ kwargs (`Dict[str, Any]`, *optional*):
+ The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization
+ supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and
+ documentation of arguments can be found in
+ https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
+
+ Example:
+ ```python
+ from diffusers import FluxTransformer2DModel, TorchAoConfig
+
+ quantization_config = TorchAoConfig("int8wo")
+ transformer = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/Flux.1-Dev",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ )
+ ```
+ """
+
+ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None:
+ self.quant_method = QuantizationMethod.TORCHAO
+ self.quant_type = quant_type
+ self.modules_to_not_convert = modules_to_not_convert
+
+ # When we load from serialized config, "quant_type_kwargs" will be the key
+ if "quant_type_kwargs" in kwargs:
+ self.quant_type_kwargs = kwargs["quant_type_kwargs"]
+ else:
+ self.quant_type_kwargs = kwargs
+
+ TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
+ if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
+ is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
+ if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
+ raise ValueError(
+ f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
+ f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
+ )
+
+ raise ValueError(
+ f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
+ f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
+ signature = inspect.signature(method)
+ all_kwargs = {
+ param.name
+ for param in signature.parameters.values()
+ if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
+ }
+ unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
+
+ if len(unsupported_kwargs) > 0:
+ raise ValueError(
+ f'The quantization method "{quant_type}" does not support the following keyword arguments: '
+ f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
+ )
+
+ @classmethod
+ def _get_torchao_quant_type_to_method(cls):
+ r"""
+ Returns supported torchao quantization types with all commonly used notations.
+ """
+
+ if is_torchao_available():
+ # TODO(aryan): Support autoquant and sparsify
+ from torchao.quantization import (
+ float8_dynamic_activation_float8_weight,
+ float8_static_activation_float8_weight,
+ float8_weight_only,
+ fpx_weight_only,
+ int4_weight_only,
+ int8_dynamic_activation_int4_weight,
+ int8_dynamic_activation_int8_weight,
+ int8_weight_only,
+ uintx_weight_only,
+ )
+
+ # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
+ from torchao.quantization.observer import PerRow, PerTensor
+
+ def generate_float8dq_types(dtype: torch.dtype):
+ name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
+ types = {}
+
+ for granularity_cls in [PerTensor, PerRow]:
+ # Note: Activation and Weights cannot have different granularities
+ granularity_name = "tensor" if granularity_cls is PerTensor else "row"
+ types[f"float8dq_{name}_{granularity_name}"] = partial(
+ float8_dynamic_activation_float8_weight,
+ activation_dtype=dtype,
+ weight_dtype=dtype,
+ granularity=(granularity_cls(), granularity_cls()),
+ )
+
+ return types
+
+ def generate_fpx_quantization_types(bits: int):
+ types = {}
+
+ for ebits in range(1, bits):
+ mbits = bits - ebits - 1
+ types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
+
+ non_sign_bits = bits - 1
+ default_ebits = (non_sign_bits + 1) // 2
+ default_mbits = non_sign_bits - default_ebits
+ types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
+
+ return types
+
+ INT4_QUANTIZATION_TYPES = {
+ # int4 weight + bfloat16/float16 activation
+ "int4wo": int4_weight_only,
+ "int4_weight_only": int4_weight_only,
+ # int4 weight + int8 activation
+ "int4dq": int8_dynamic_activation_int4_weight,
+ "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
+ }
+
+ INT8_QUANTIZATION_TYPES = {
+ # int8 weight + bfloat16/float16 activation
+ "int8wo": int8_weight_only,
+ "int8_weight_only": int8_weight_only,
+ # int8 weight + int8 activation
+ "int8dq": int8_dynamic_activation_int8_weight,
+ "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
+ }
+
+ # TODO(aryan): handle torch 2.2/2.3
+ FLOATX_QUANTIZATION_TYPES = {
+ # float8_e5m2 weight + bfloat16/float16 activation
+ "float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
+ "float8_weight_only": float8_weight_only,
+ "float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
+ # float8_e4m3 weight + bfloat16/float16 activation
+ "float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
+ # float8_e5m2 weight + float8 activation (dynamic)
+ "float8dq": float8_dynamic_activation_float8_weight,
+ "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
+ # ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out.
+ # However, changing activation_dtype=torch.float8_e4m3 might work here =====
+ # "float8dq_e5m2": partial(
+ # float8_dynamic_activation_float8_weight,
+ # activation_dtype=torch.float8_e5m2,
+ # weight_dtype=torch.float8_e5m2,
+ # ),
+ # **generate_float8dq_types(torch.float8_e5m2),
+ # ===== =====
+ # float8_e4m3 weight + float8 activation (dynamic)
+ "float8dq_e4m3": partial(
+ float8_dynamic_activation_float8_weight,
+ activation_dtype=torch.float8_e4m3fn,
+ weight_dtype=torch.float8_e4m3fn,
+ ),
+ **generate_float8dq_types(torch.float8_e4m3fn),
+ # float8 weight + float8 activation (static)
+ "float8_static_activation_float8_weight": float8_static_activation_float8_weight,
+ # For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
+ # fpx weight + bfloat16/float16 activation
+ **generate_fpx_quantization_types(3),
+ **generate_fpx_quantization_types(4),
+ **generate_fpx_quantization_types(5),
+ **generate_fpx_quantization_types(6),
+ **generate_fpx_quantization_types(7),
+ }
+
+ UINTX_QUANTIZATION_DTYPES = {
+ "uintx_weight_only": uintx_weight_only,
+ "uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
+ "uint2wo": partial(uintx_weight_only, dtype=torch.uint2),
+ "uint3wo": partial(uintx_weight_only, dtype=torch.uint3),
+ "uint4wo": partial(uintx_weight_only, dtype=torch.uint4),
+ "uint5wo": partial(uintx_weight_only, dtype=torch.uint5),
+ "uint6wo": partial(uintx_weight_only, dtype=torch.uint6),
+ "uint7wo": partial(uintx_weight_only, dtype=torch.uint7),
+ # "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported
+ }
+
+ QUANTIZATION_TYPES = {}
+ QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES)
+ QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
+ QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
+
+ if cls._is_cuda_capability_atleast_8_9():
+ QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
+
+ return QUANTIZATION_TYPES
+ else:
+ raise ValueError(
+ "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
+ )
+
+ @staticmethod
+ def _is_cuda_capability_atleast_8_9() -> bool:
+ if not torch.cuda.is_available():
+ raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
+
+ major, minor = torch.cuda.get_device_capability()
+ if major == 8:
+ return minor >= 9
+ return major >= 9
+
+ def get_apply_tensor_subclass(self):
+ TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
+ return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
+
+ def __repr__(self):
+ r"""
+ Example of how this looks for `TorchAoConfig("uint4wo", group_size=32)`:
+
+ ```
+ TorchAoConfig {
+ "modules_to_not_convert": null,
+ "quant_method": "torchao",
+ "quant_type": "uint4wo",
+ "quant_type_kwargs": {
+ "group_size": 32
+ }
+ }
+ ```
+ """
+ config_dict = self.to_dict()
+ return (
+ f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
+ )
+
+
+@dataclass
+class QuantoConfig(QuantizationConfigMixin):
+ """
+ This is a wrapper class about all possible attributes and features that you can play with a model that has been
+ loaded using `quanto`.
+
+ Args:
+ weights_dtype (`str`, *optional*, defaults to `"int8"`):
+ The target dtype for the weights after quantization. Supported values are ("float8","int8","int4","int2")
+ modules_to_not_convert (`list`, *optional*, default to `None`):
+ The list of modules to not quantize, useful for quantizing models that explicitly require to have some
+ modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
+ """
+
+ def __init__(
+ self,
+ weights_dtype: str = "int8",
+ modules_to_not_convert: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ self.quant_method = QuantizationMethod.QUANTO
+ self.weights_dtype = weights_dtype
+ self.modules_to_not_convert = modules_to_not_convert
+
+ self.post_init()
+
+ def post_init(self):
+ r"""
+ Safety checker that arguments are correct
+ """
+ accepted_weights = ["float8", "int8", "int4", "int2"]
+ if self.weights_dtype not in accepted_weights:
+ raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
diff --git a/src/diffusers/quantizers/quanto/__init__.py b/src/diffusers/quantizers/quanto/__init__.py
new file mode 100644
index 000000000000..a4e8a1f41a1e
--- /dev/null
+++ b/src/diffusers/quantizers/quanto/__init__.py
@@ -0,0 +1 @@
+from .quanto_quantizer import QuantoQuantizer
diff --git a/src/diffusers/quantizers/quanto/quanto_quantizer.py b/src/diffusers/quantizers/quanto/quanto_quantizer.py
new file mode 100644
index 000000000000..0120163804c9
--- /dev/null
+++ b/src/diffusers/quantizers/quanto/quanto_quantizer.py
@@ -0,0 +1,177 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from diffusers.utils.import_utils import is_optimum_quanto_version
+
+from ...utils import (
+ get_module_from_name,
+ is_accelerate_available,
+ is_accelerate_version,
+ is_optimum_quanto_available,
+ is_torch_available,
+ logging,
+)
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+ from ...models.modeling_utils import ModelMixin
+
+
+if is_torch_available():
+ import torch
+
+if is_accelerate_available():
+ from accelerate.utils import CustomDtype, set_module_tensor_to_device
+
+if is_optimum_quanto_available():
+ from .utils import _replace_with_quanto_layers
+
+logger = logging.get_logger(__name__)
+
+
+class QuantoQuantizer(DiffusersQuantizer):
+ r"""
+ Diffusers Quantizer for Optimum Quanto
+ """
+
+ use_keep_in_fp32_modules = True
+ requires_calibration = False
+ required_packages = ["quanto", "accelerate"]
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ def validate_environment(self, *args, **kwargs):
+ if not is_optimum_quanto_available():
+ raise ImportError(
+ "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
+ )
+ if not is_optimum_quanto_version(">=", "0.2.6"):
+ raise ImportError(
+ "Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. "
+ "Please upgrade your installation with `pip install --upgrade optimum-quanto"
+ )
+
+ if not is_accelerate_available():
+ raise ImportError(
+ "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
+ )
+
+ device_map = kwargs.get("device_map", None)
+ if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+ raise ValueError(
+ "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend"
+ )
+
+ def check_if_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ):
+ # Quanto imports diffusers internally. This is here to prevent circular imports
+ from optimum.quanto import QModuleMixin, QTensor
+ from optimum.quanto.tensor.packed import PackedTensor
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if self.pre_quantized and any(isinstance(module, t) for t in [QTensor, PackedTensor]):
+ return True
+ elif isinstance(module, QModuleMixin) and "weight" in tensor_name:
+ return not module.frozen
+
+ return False
+
+ def create_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ *args,
+ **kwargs,
+ ):
+ """
+ Create the quantized parameter by calling .freeze() after setting it to the module.
+ """
+
+ dtype = kwargs.get("dtype", torch.float32)
+ module, tensor_name = get_module_from_name(model, param_name)
+ if self.pre_quantized:
+ setattr(module, tensor_name, param_value)
+ else:
+ set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
+ module.freeze()
+ module.weight.requires_grad = False
+
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+ return max_memory
+
+ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+ if is_accelerate_version(">=", "0.27.0"):
+ mapping = {
+ "int8": torch.int8,
+ "float8": CustomDtype.FP8,
+ "int4": CustomDtype.INT4,
+ "int2": CustomDtype.INT2,
+ }
+ target_dtype = mapping[self.quantization_config.weights_dtype]
+
+ return target_dtype
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
+ if torch_dtype is None:
+ logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
+ torch_dtype = torch.float32
+ return torch_dtype
+
+ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
+ # Quanto imports diffusers internally. This is here to prevent circular imports
+ from optimum.quanto import QModuleMixin
+
+ not_missing_keys = []
+ for name, module in model.named_modules():
+ if isinstance(module, QModuleMixin):
+ for missing in missing_keys:
+ if (
+ (name in missing or name in f"{prefix}.{missing}")
+ and not missing.endswith(".weight")
+ and not missing.endswith(".bias")
+ ):
+ not_missing_keys.append(missing)
+ return [k for k in missing_keys if k not in not_missing_keys]
+
+ def _process_model_before_weight_loading(
+ self,
+ model: "ModelMixin",
+ device_map,
+ keep_in_fp32_modules: List[str] = [],
+ **kwargs,
+ ):
+ self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
+
+ if not isinstance(self.modules_to_not_convert, list):
+ self.modules_to_not_convert = [self.modules_to_not_convert]
+
+ self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+ model = _replace_with_quanto_layers(
+ model,
+ modules_to_not_convert=self.modules_to_not_convert,
+ quantization_config=self.quantization_config,
+ pre_quantized=self.pre_quantized,
+ )
+ model.config.quantization_config = self.quantization_config
+
+ def _process_model_after_weight_loading(self, model, **kwargs):
+ return model
+
+ @property
+ def is_trainable(self):
+ return True
+
+ @property
+ def is_serializable(self):
+ return True
diff --git a/src/diffusers/quantizers/quanto/utils.py b/src/diffusers/quantizers/quanto/utils.py
new file mode 100644
index 000000000000..6f41fd36b43a
--- /dev/null
+++ b/src/diffusers/quantizers/quanto/utils.py
@@ -0,0 +1,60 @@
+import torch.nn as nn
+
+from ...utils import is_accelerate_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+
+def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list, pre_quantized=False):
+ # Quanto imports diffusers internally. These are placed here to avoid circular imports
+ from optimum.quanto import QLinear, freeze, qfloat8, qint2, qint4, qint8
+
+ def _get_weight_type(dtype: str):
+ return {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}[dtype]
+
+ def _replace_layers(model, quantization_config, modules_to_not_convert):
+ has_children = list(model.children())
+ if not has_children:
+ return model
+
+ for name, module in model.named_children():
+ _replace_layers(module, quantization_config, modules_to_not_convert)
+
+ if name in modules_to_not_convert:
+ continue
+
+ if isinstance(module, nn.Linear):
+ with init_empty_weights():
+ qlinear = QLinear(
+ in_features=module.in_features,
+ out_features=module.out_features,
+ bias=module.bias is not None,
+ dtype=module.weight.dtype,
+ weights=_get_weight_type(quantization_config.weights_dtype),
+ )
+ model._modules[name] = qlinear
+ model._modules[name].source_cls = type(module)
+ model._modules[name].requires_grad_(False)
+
+ return model
+
+ model = _replace_layers(model, quantization_config, modules_to_not_convert)
+ has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules())
+
+ if not has_been_replaced:
+ logger.warning(
+ f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied."
+ " Please check your model architecture, or submit an issue on Github if you think this is a bug."
+ " https://github.com/huggingface/diffusers/issues/new"
+ )
+
+ # We need to freeze the pre_quantized model in order for the loaded state_dict and model state dict
+ # to match when trying to load weights with load_model_dict_into_meta
+ if pre_quantized:
+ freeze(model)
+
+ return model
diff --git a/src/diffusers/quantizers/torchao/__init__.py b/src/diffusers/quantizers/torchao/__init__.py
new file mode 100644
index 000000000000..c56bf54c2515
--- /dev/null
+++ b/src/diffusers/quantizers/torchao/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .torchao_quantizer import TorchAoHfQuantizer
diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
new file mode 100644
index 000000000000..f9fb217ed6bd
--- /dev/null
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -0,0 +1,337 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/src/transformers/quantizers/quantizer_torchao.py
+"""
+
+import importlib
+import types
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from packaging import version
+
+from ...utils import (
+ get_module_from_name,
+ is_torch_available,
+ is_torch_version,
+ is_torchao_available,
+ is_torchao_version,
+ logging,
+)
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+ from ...models.modeling_utils import ModelMixin
+
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+ if is_torch_version(">=", "2.5"):
+ SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
+ # At the moment, only int8 is supported for integer quantization dtypes.
+ # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
+ # to support more quantization methods, such as intx_weight_only.
+ torch.int8,
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ torch.uint1,
+ torch.uint2,
+ torch.uint3,
+ torch.uint4,
+ torch.uint5,
+ torch.uint6,
+ torch.uint7,
+ )
+ else:
+ SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
+ torch.int8,
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ )
+
+if is_torchao_available():
+ from torchao.quantization import quantize_
+
+
+def _update_torch_safe_globals():
+ safe_globals = [
+ (torch.uint1, "torch.uint1"),
+ (torch.uint2, "torch.uint2"),
+ (torch.uint3, "torch.uint3"),
+ (torch.uint4, "torch.uint4"),
+ (torch.uint5, "torch.uint5"),
+ (torch.uint6, "torch.uint6"),
+ (torch.uint7, "torch.uint7"),
+ ]
+ try:
+ from torchao.dtypes import NF4Tensor
+ from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
+ from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
+ from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
+
+ safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
+
+ except (ImportError, ModuleNotFoundError) as e:
+ logger.warning(
+ "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
+ )
+ logger.debug(e)
+
+ finally:
+ torch.serialization.add_safe_globals(safe_globals=safe_globals)
+
+
+if (
+ is_torch_available()
+ and is_torch_version(">=", "2.6.0")
+ and is_torchao_available()
+ and is_torchao_version(">=", "0.7.0")
+):
+ _update_torch_safe_globals()
+
+
+logger = logging.get_logger(__name__)
+
+
+def _quantization_type(weight):
+ from torchao.dtypes import AffineQuantizedTensor
+ from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+
+ if isinstance(weight, AffineQuantizedTensor):
+ return f"{weight.__class__.__name__}({weight._quantization_type()})"
+
+ if isinstance(weight, LinearActivationQuantizedTensor):
+ return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
+
+
+def _linear_extra_repr(self):
+ weight = _quantization_type(self.weight)
+ if weight is None:
+ return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
+ else:
+ return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
+
+
+class TorchAoHfQuantizer(DiffusersQuantizer):
+ r"""
+ Diffusers Quantizer for TorchAO: https://github.com/pytorch/ao/.
+ """
+
+ requires_calibration = False
+ required_packages = ["torchao"]
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ def validate_environment(self, *args, **kwargs):
+ if not is_torchao_available():
+ raise ImportError(
+ "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
+ )
+ torchao_version = version.parse(importlib.metadata.version("torch"))
+ if torchao_version < version.parse("0.7.0"):
+ raise RuntimeError(
+ f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
+ )
+
+ self.offload = False
+
+ device_map = kwargs.get("device_map", None)
+ if isinstance(device_map, dict):
+ if "cpu" in device_map.values() or "disk" in device_map.values():
+ if self.pre_quantized:
+ raise ValueError(
+ "You are attempting to perform cpu/disk offload with a pre-quantized torchao model "
+ "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
+ )
+ else:
+ self.offload = True
+
+ if self.pre_quantized:
+ weights_only = kwargs.get("weights_only", None)
+ if weights_only:
+ torch_version = version.parse(importlib.metadata.version("torch"))
+ if torch_version < version.parse("2.5.0"):
+ # TODO(aryan): TorchAO is compatible with Pytorch >= 2.2 for certain quantization types. Try to see if we can support it in future
+ raise RuntimeError(
+ f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
+ )
+
+ def update_torch_dtype(self, torch_dtype):
+ quant_type = self.quantization_config.quant_type
+
+ if quant_type.startswith("int") or quant_type.startswith("uint"):
+ if torch_dtype is not None and torch_dtype != torch.bfloat16:
+ logger.warning(
+ f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
+ f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
+ )
+
+ if torch_dtype is None:
+ # We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
+ logger.warning(
+ "Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` "
+ "to enable model loading in different precisions. Pass your own `torch_dtype` to specify the "
+ "dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning."
+ )
+ torch_dtype = torch.bfloat16
+
+ return torch_dtype
+
+ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+ quant_type = self.quantization_config.quant_type
+
+ if quant_type.startswith("int8") or quant_type.startswith("int4"):
+ # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
+ return torch.int8
+ elif quant_type == "uintx_weight_only":
+ return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
+ elif quant_type.startswith("uint"):
+ return {
+ 1: torch.uint1,
+ 2: torch.uint2,
+ 3: torch.uint3,
+ 4: torch.uint4,
+ 5: torch.uint5,
+ 6: torch.uint6,
+ 7: torch.uint7,
+ }[int(quant_type[4])]
+ elif quant_type.startswith("float") or quant_type.startswith("fp"):
+ return torch.bfloat16
+
+ if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
+ return target_dtype
+
+ # We need one of the supported dtypes to be selected in order for accelerate to determine
+ # the total size of modules/parameters for auto device placement.
+ possible_device_maps = ["auto", "balanced", "balanced_low_0", "sequential"]
+ raise ValueError(
+ f"You have set `device_map` as one of {possible_device_maps} on a TorchAO quantized model but a suitable target dtype "
+ f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the "
+ f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ max_memory = {key: val * 0.9 for key, val in max_memory.items()}
+ return max_memory
+
+ def check_if_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ) -> bool:
+ param_device = kwargs.pop("param_device", None)
+ # Check if the param_name is not in self.modules_to_not_convert
+ if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
+ return False
+ elif param_device == "cpu" and self.offload:
+ # We don't quantize weights that we offload
+ return False
+ else:
+ # We only quantize the weight of nn.Linear
+ module, tensor_name = get_module_from_name(model, param_name)
+ return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
+
+ def create_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ state_dict: Dict[str, Any],
+ unexpected_keys: List[str],
+ **kwargs,
+ ):
+ r"""
+ Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor,
+ then we move it to the target device. Finally, we quantize the module.
+ """
+ module, tensor_name = get_module_from_name(model, param_name)
+
+ if self.pre_quantized:
+ # If we're loading pre-quantized weights, replace the repr of linear layers for pretty printing info
+ # about AffineQuantizedTensor
+ module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
+ if isinstance(module, nn.Linear):
+ module.extra_repr = types.MethodType(_linear_extra_repr, module)
+ else:
+ # As we perform quantization here, the repr of linear layers is that of AQT, so we don't have to do it ourselves
+ module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+ quantize_(module, self.quantization_config.get_apply_tensor_subclass())
+
+ def _process_model_before_weight_loading(
+ self,
+ model: "ModelMixin",
+ device_map,
+ keep_in_fp32_modules: List[str] = [],
+ **kwargs,
+ ):
+ self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
+
+ if not isinstance(self.modules_to_not_convert, list):
+ self.modules_to_not_convert = [self.modules_to_not_convert]
+
+ self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+ # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+ if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+ keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+ self.modules_to_not_convert.extend(keys_on_cpu)
+
+ # Purge `None`.
+ # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
+ # in case of diffusion transformer models. For language models and others alike, `lm_head`
+ # and tied modules are usually kept in FP32.
+ self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
+
+ model.config.quantization_config = self.quantization_config
+
+ def _process_model_after_weight_loading(self, model: "ModelMixin"):
+ return model
+
+ def is_serializable(self, safe_serialization=None):
+ # TODO(aryan): needs to be tested
+ if safe_serialization:
+ logger.warning(
+ "torchao quantized model does not support safe serialization, please set `safe_serialization` to False."
+ )
+ return False
+
+ _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
+ "0.25.0"
+ )
+
+ if not _is_torchao_serializable:
+ logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
+
+ if self.offload and self.quantization_config.modules_to_not_convert is None:
+ logger.warning(
+ "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
+ "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
+ )
+ return False
+
+ return _is_torchao_serializable
+
+ @property
+ def is_trainable(self):
+ return self.quantization_config.quant_type.startswith("int8")
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index bb9088538653..05cd21cd0034 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -68,6 +68,7 @@
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
_import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
+ _import_structure["scheduling_scm"] = ["SCMScheduler"]
_import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
_import_structure["scheduling_tcd"] = ["TCDScheduler"]
_import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
@@ -168,13 +169,13 @@
from .scheduling_pndm import PNDMScheduler
from .scheduling_repaint import RePaintScheduler
from .scheduling_sasolver import SASolverScheduler
+ from .scheduling_scm import SCMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_tcd import TCDScheduler
from .scheduling_unclip import UnCLIPScheduler
from .scheduling_unipc_multistep import UniPCMultistepScheduler
from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler
-
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py
index 6c2352f2c828..d9d9ae683ad0 100644
--- a/src/diffusers/schedulers/scheduling_ddim_inverse.py
+++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py
@@ -266,7 +266,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.num_inference_steps = num_inference_steps
- # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
+ # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 468fdf61a9ef..624d5a5cd4f3 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -142,7 +142,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ `linear`, `scaled_linear`, `squaredcos_cap_v2`, or `sigmoid`.
trained_betas (`np.ndarray`, *optional*):
An array of betas to pass directly to the constructor without using `beta_start` and `beta_end`.
variance_type (`str`, defaults to `"fixed_small"`):
@@ -548,16 +548,12 @@ def __len__(self):
return self.config.num_train_timesteps
def previous_timestep(self, timestep):
- if self.custom_timesteps:
+ if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
- num_inference_steps = (
- self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
- )
- prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+ prev_t = timestep - 1
return prev_t
diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
index f377ee6e8c93..20ad7a4c927d 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -639,16 +639,12 @@ def __len__(self):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
- if self.custom_timesteps:
+ if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
- num_inference_steps = (
- self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
- )
- prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+ prev_t = timestep - 1
return prev_t
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index 6fe8474aab87..6a653f183bba 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -149,6 +149,8 @@ def __init__(
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
+ use_flow_sigmas: Optional[bool] = False,
+ flow_shift: Optional[float] = 1.0,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
@@ -266,18 +268,28 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ log_sigmas = np.log(sigmas)
if self.config.use_karras_sigmas:
- log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = np.flip(sigmas).copy()
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = np.flip(sigmas).copy()
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+ elif self.config.use_flow_sigmas:
+ alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+ sigmas = 1.0 - alphas
+ sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+ timesteps = (sigmas * self.config.num_train_timesteps).copy()
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -358,8 +370,12 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
- alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
- sigma_t = sigma * alpha_t
+ if self.config.use_flow_sigmas:
+ alpha_t = 1 - sigma
+ sigma_t = sigma
+ else:
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+ sigma_t = sigma * alpha_t
return alpha_t, sigma_t
@@ -408,7 +424,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -432,7 +448,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
@@ -486,10 +502,13 @@ def convert_model_output(
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
x0_pred = alpha_t * sample - sigma_t * model_output
+ elif self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
- " `v_prediction` for the DEISMultistepScheduler."
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ "`v_prediction`, or `flow_prediction` for the DEISMultistepScheduler."
)
if self.config.thresholding:
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 7677e37e9426..ed60dd4eaee1 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -136,8 +136,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -174,6 +174,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
`lambda(t)`.
+ use_flow_sigmas (`bool`, *optional*, defaults to `False`):
+ Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
+ flow_shift (`float`, *optional*, defaults to 1.0):
+ The shift value for the timestep schedule for flow matching.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
@@ -218,6 +222,8 @@ def __init__(
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
use_lu_lambdas: Optional[bool] = False,
+ use_flow_sigmas: Optional[bool] = False,
+ flow_shift: Optional[float] = 1.0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
@@ -393,18 +399,29 @@ def set_timesteps(
if self.config.use_karras_sigmas:
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
- timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ if self.config.beta_schedule != "squaredcos_cap_v2":
+ timesteps = timesteps.round()
elif self.config.use_lu_lambdas:
lambdas = np.flip(log_sigmas.copy())
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
sigmas = np.exp(lambdas)
- timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ if self.config.beta_schedule != "squaredcos_cap_v2":
+ timesteps = timesteps.round()
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = np.flip(sigmas).copy()
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = np.flip(sigmas).copy()
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ elif self.config.use_flow_sigmas:
+ alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+ sigmas = 1.0 - alphas
+ sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+ timesteps = (sigmas * self.config.num_train_timesteps).copy()
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -493,8 +510,12 @@ def _sigma_to_t(self, sigma, log_sigmas):
return t
def _sigma_to_alpha_sigma_t(self, sigma):
- alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
- sigma_t = sigma * alpha_t
+ if self.config.use_flow_sigmas:
+ alpha_t = 1 - sigma
+ sigma_t = sigma
+ else:
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+ sigma_t = sigma * alpha_t
return alpha_t, sigma_t
@@ -556,7 +577,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -580,7 +601,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
@@ -648,10 +669,13 @@ def convert_model_output(
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = alpha_t * sample - sigma_t * model_output
+ elif self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
- " `v_prediction` for the DPMSolverMultistepScheduler."
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
)
if self.config.thresholding:
@@ -887,6 +911,7 @@ def multistep_dpm_solver_third_order_update(
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -965,6 +990,15 @@ def multistep_dpm_solver_third_order_update(
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ x_t = (
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -1071,7 +1105,7 @@ def step(
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
else:
- prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
+ prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index c26a464518f0..971817f7b777 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -169,6 +169,8 @@ def __init__(
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
+ use_flow_sigmas: Optional[bool] = False,
+ flow_shift: Optional[float] = 1.0,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
@@ -287,11 +289,19 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
timesteps = timesteps.copy().astype(np.int64)
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+ elif self.config.use_flow_sigmas:
+ alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+ sigmas = 1.0 - alphas
+ sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+ timesteps = (sigmas * self.config.num_train_timesteps).copy()
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_max = (
@@ -379,8 +389,12 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
- alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
- sigma_t = sigma * alpha_t
+ if self.config.use_flow_sigmas:
+ alpha_t = 1 - sigma
+ sigma_t = sigma
+ else:
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+ sigma_t = sigma * alpha_t
return alpha_t, sigma_t
@@ -429,7 +443,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -453,7 +467,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
@@ -522,10 +536,13 @@ def convert_model_output(
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = alpha_t * sample - sigma_t * model_output
+ elif self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
- " `v_prediction` for the DPMSolverMultistepScheduler."
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
)
if self.config.thresholding:
@@ -764,6 +781,7 @@ def multistep_dpm_solver_third_order_update(
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -842,6 +860,15 @@ def multistep_dpm_solver_third_order_update(
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ x_t = (
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
return x_t
def _init_step_index(self, timestep):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
index a7cc4209fec4..6c9cb975fe34 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
@@ -380,10 +380,10 @@ def set_timesteps(
sigmas = self._convert_to_karras(in_sigmas=sigmas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)
@@ -484,7 +484,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -508,7 +508,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 6841a34a6489..bf68d6c99bd6 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -164,6 +164,8 @@ def __init__(
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
+ use_flow_sigmas: Optional[bool] = False,
+ flow_shift: Optional[float] = 1.0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
@@ -264,6 +266,10 @@ def get_order_list(self, num_inference_steps: int) -> List[int]:
orders = [1, 2] * (steps // 2)
elif order == 1:
orders = [1] * steps
+
+ if self.config.final_sigmas_type == "zero":
+ orders[-1] = 1
+
return orders
@property
@@ -339,17 +345,24 @@ def set_timesteps(
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ log_sigmas = np.log(sigmas)
if self.config.use_karras_sigmas:
- log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = np.flip(sigmas).copy()
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = np.flip(sigmas).copy()
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ elif self.config.use_flow_sigmas:
+ alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+ sigmas = 1.0 - alphas
+ sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+ timesteps = (sigmas * self.config.num_train_timesteps).copy()
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -448,8 +461,12 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
- alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
- sigma_t = sigma * alpha_t
+ if self.config.use_flow_sigmas:
+ alpha_t = 1 - sigma
+ sigma_t = sigma
+ else:
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+ sigma_t = sigma * alpha_t
return alpha_t, sigma_t
@@ -498,7 +515,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -522,7 +539,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
@@ -589,10 +606,13 @@ def convert_model_output(
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = alpha_t * sample - sigma_t * model_output
+ elif self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
- " `v_prediction` for the DPMSolverSinglestepScheduler."
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ "`v_prediction`, or `flow_prediction` for the DPMSolverSinglestepScheduler."
)
if self.config.thresholding:
@@ -810,6 +830,7 @@ def singlestep_dpm_solver_third_order_update(
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -907,6 +928,23 @@ def singlestep_dpm_solver_third_order_update(
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
)
+ elif self.config.algorithm_type == "sde-dpmsolver++":
+ assert noise is not None
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s2 * torch.exp(-h)) * sample
+ + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1_1
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s2 * torch.exp(-h)) * sample
+ + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h) + (-2.0 * h)) / (-2.0 * h) ** 2 - 0.5)) * D2
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
return x_t
def singlestep_dpm_solver_update(
@@ -968,7 +1006,7 @@ def singlestep_dpm_solver_update(
elif order == 2:
return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
elif order == 3:
- return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
+ return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample, noise=noise)
else:
raise ValueError(f"Order must be 1, 2, 3, got {order}")
diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py
index d25947d8d331..0617cc44d75a 100644
--- a/src/diffusers/schedulers/scheduling_edm_euler.py
+++ b/src/diffusers/schedulers/scheduling_edm_euler.py
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
import torch
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
Video](https://imagen.research.google/video/paper.pdf) paper).
rho (`float`, *optional*, defaults to 7.0):
The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
+ final_sigmas_type (`str`, defaults to `"zero"`):
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
"""
_compatibles = []
@@ -92,6 +95,7 @@ def __init__(
num_train_timesteps: int = 1000,
prediction_type: str = "epsilon",
rho: float = 7.0,
+ final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
):
if sigma_schedule not in ["karras", "exponential"]:
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
@@ -99,15 +103,24 @@ def __init__(
# setable values
self.num_inference_steps = None
- ramp = torch.linspace(0, 1, num_train_timesteps)
+ sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
if sigma_schedule == "karras":
- sigmas = self._compute_karras_sigmas(ramp)
+ sigmas = self._compute_karras_sigmas(sigmas)
elif sigma_schedule == "exponential":
- sigmas = self._compute_exponential_sigmas(ramp)
+ sigmas = self._compute_exponential_sigmas(sigmas)
self.timesteps = self.precondition_noise(sigmas)
- self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = sigmas[-1]
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+
+ self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
self.is_scale_input_called = False
@@ -197,7 +210,12 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
self.is_scale_input_called = True
return sample
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(
+ self,
+ num_inference_steps: int = None,
+ device: Union[str, torch.device] = None,
+ sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
+ ):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -206,19 +224,36 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
+ Custom sigmas to use for the denoising process. If not defined, the default behavior when
+ `num_inference_steps` is passed will be used.
"""
self.num_inference_steps = num_inference_steps
- ramp = torch.linspace(0, 1, self.num_inference_steps)
+ if sigmas is None:
+ sigmas = torch.linspace(0, 1, self.num_inference_steps)
+ elif isinstance(sigmas, float):
+ sigmas = torch.tensor(sigmas, dtype=torch.float32)
+ else:
+ sigmas = sigmas
if self.config.sigma_schedule == "karras":
- sigmas = self._compute_karras_sigmas(ramp)
+ sigmas = self._compute_karras_sigmas(sigmas)
elif self.config.sigma_schedule == "exponential":
- sigmas = self._compute_exponential_sigmas(ramp)
+ sigmas = self._compute_exponential_sigmas(sigmas)
sigmas = sigmas.to(dtype=torch.float32, device=device)
self.timesteps = self.precondition_noise(sigmas)
- self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = sigmas[-1]
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+
+ self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 135c48825832..56757f3ca197 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -419,11 +419,11 @@ def set_timesteps(
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
if self.config.final_sigmas_type == "sigma_min":
@@ -517,7 +517,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
def _convert_to_beta(
@@ -540,7 +540,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index 937cae2e47f5..cbb27e5fad63 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -20,10 +20,13 @@
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
from .scheduling_utils import SchedulerMixin
+if is_scipy_available():
+ import scipy.stats
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -51,11 +54,32 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
- timestep_spacing (`str`, defaults to `"linspace"`):
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
+ use_dynamic_shifting (`bool`, defaults to False):
+ Whether to apply timestep shifting on-the-fly based on the image resolution.
+ base_shift (`float`, defaults to 0.5):
+ Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent
+ with desired output.
+ max_shift (`float`, defaults to 1.15):
+ Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be
+ more exaggerated or stylized.
+ base_image_seq_len (`int`, defaults to 256):
+ The base image sequence length.
+ max_image_seq_len (`int`, defaults to 4096):
+ The maximum image sequence length.
+ invert_sigmas (`bool`, defaults to False):
+ Whether to invert the sigmas.
+ shift_terminal (`float`, defaults to None):
+ The end value of the shifted timestep schedule.
+ use_karras_sigmas (`bool`, defaults to False):
+ Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
+ use_exponential_sigmas (`bool`, defaults to False):
+ Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
+ use_beta_sigmas (`bool`, defaults to False):
+ Whether to use beta sigmas for step sizes in the noise schedule during sampling.
+ time_shift_type (`str`, defaults to "exponential"):
+ The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
"""
_compatibles = []
@@ -66,12 +90,27 @@ def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
- use_dynamic_shifting=False,
+ use_dynamic_shifting: bool = False,
base_shift: Optional[float] = 0.5,
max_shift: Optional[float] = 1.15,
base_image_seq_len: Optional[int] = 256,
max_image_seq_len: Optional[int] = 4096,
+ invert_sigmas: bool = False,
+ shift_terminal: Optional[float] = None,
+ use_karras_sigmas: Optional[bool] = False,
+ use_exponential_sigmas: Optional[bool] = False,
+ use_beta_sigmas: Optional[bool] = False,
+ time_shift_type: str = "exponential",
):
+ if self.config.use_beta_sigmas and not is_scipy_available():
+ raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+ if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+ raise ValueError(
+ "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+ )
+ if time_shift_type not in {"exponential", "linear"}:
+ raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")
+
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
@@ -85,10 +124,19 @@ def __init__(
self._step_index = None
self._begin_index = None
+ self._shift = shift
+
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
+ @property
+ def shift(self):
+ """
+ The value used for shifting.
+ """
+ return self._shift
+
@property
def step_index(self):
"""
@@ -114,6 +162,9 @@ def set_begin_index(self, begin_index: int = 0):
"""
self._begin_index = begin_index
+ def set_shift(self, shift: float):
+ self._shift = shift
+
def scale_noise(
self,
sample: torch.FloatTensor,
@@ -166,47 +217,131 @@ def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+ if self.config.time_shift_type == "exponential":
+ return self._time_shift_exponential(mu, sigma, t)
+ elif self.config.time_shift_type == "linear":
+ return self._time_shift_linear(mu, sigma, t)
+
+ def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
+ r"""
+ Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
+ value.
+
+ Reference:
+ https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
+
+ Args:
+ t (`torch.Tensor`):
+ A tensor of timesteps to be stretched and shifted.
+
+ Returns:
+ `torch.Tensor`:
+ A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
+ """
+ one_minus_z = 1 - t
+ scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
+ stretched_t = 1 - (one_minus_z / scale_factor)
+ return stretched_t
def set_timesteps(
self,
- num_inference_steps: int = None,
+ num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
+ timesteps: Optional[List[float]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
- num_inference_steps (`int`):
+ num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ sigmas (`List[float]`, *optional*):
+ Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
+ automatically.
+ mu (`float`, *optional*):
+ Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
+ shifting.
+ timesteps (`List[float]`, *optional*):
+ Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
+ automatically.
"""
-
if self.config.use_dynamic_shifting and mu is None:
- raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
+ raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
+
+ if sigmas is not None and timesteps is not None:
+ if len(sigmas) != len(timesteps):
+ raise ValueError("`sigmas` and `timesteps` should have the same length")
+
+ if num_inference_steps is not None:
+ if (sigmas is not None and len(sigmas) != num_inference_steps) or (
+ timesteps is not None and len(timesteps) != num_inference_steps
+ ):
+ raise ValueError(
+ "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
+ )
+ else:
+ num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
- if sigmas is None:
- self.num_inference_steps = num_inference_steps
- timesteps = np.linspace(
- self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
- )
+ self.num_inference_steps = num_inference_steps
+ # 1. Prepare default sigmas
+ is_timesteps_provided = timesteps is not None
+
+ if is_timesteps_provided:
+ timesteps = np.array(timesteps).astype(np.float32)
+
+ if sigmas is None:
+ if timesteps is None:
+ timesteps = np.linspace(
+ self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
+ )
sigmas = timesteps / self.config.num_train_timesteps
+ else:
+ sigmas = np.array(sigmas).astype(np.float32)
+ num_inference_steps = len(sigmas)
+ # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
+ # "exponential" or "linear" type is applied
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
- sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
+ sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
- sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
- timesteps = sigmas * self.config.num_train_timesteps
+ # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
+ if self.config.shift_terminal:
+ sigmas = self.stretch_shift_to_terminal(sigmas)
+
+ # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
+ if self.config.use_karras_sigmas:
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ elif self.config.use_exponential_sigmas:
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ elif self.config.use_beta_sigmas:
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
- self.timesteps = timesteps.to(device=device)
- self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+ # 5. Convert sigmas and timesteps to tensors and move to specified device
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+ if not is_timesteps_provided:
+ timesteps = sigmas * self.config.num_train_timesteps
+ else:
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
+
+ # 6. Append the terminal sigma value.
+ # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
+ # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
+ if self.config.invert_sigmas:
+ sigmas = 1.0 - sigmas
+ timesteps = sigmas * self.config.num_train_timesteps
+ sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
+ else:
+ sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+ self.timesteps = timesteps
+ self.sigmas = sigmas
self._step_index = None
self._begin_index = None
@@ -242,6 +377,7 @@ def step(
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
+ per_token_timesteps: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
@@ -262,14 +398,17 @@ def step(
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
+ per_token_timesteps (`torch.Tensor`, *optional*):
+ The timesteps for each token in the sample.
return_dict (`bool`):
- Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
- tuple.
+ Whether or not to return a
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
Returns:
- [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
- If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
- returned, otherwise a tuple is returned where the first element is the sample tensor.
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
+ If return_dict is `True`,
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned,
+ otherwise a tuple is returned where the first element is the sample tensor.
"""
if (
@@ -280,7 +419,7 @@ def step(
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
- " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
@@ -291,21 +430,117 @@ def step(
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
- sigma = self.sigmas[self.step_index]
- sigma_next = self.sigmas[self.step_index + 1]
+ if per_token_timesteps is not None:
+ per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
- prev_sample = sample + (sigma_next - sigma) * model_output
+ sigmas = self.sigmas[:, None, None]
+ lower_mask = sigmas < per_token_sigmas[None] - 1e-6
+ lower_sigmas = lower_mask * sigmas
+ lower_sigmas, _ = lower_sigmas.max(dim=0)
+ dt = (per_token_sigmas - lower_sigmas)[..., None]
+ else:
+ sigma = self.sigmas[self.step_index]
+ sigma_next = self.sigmas[self.step_index + 1]
+ dt = sigma_next - sigma
- # Cast sample back to model compatible dtype
- prev_sample = prev_sample.to(model_output.dtype)
+ prev_sample = sample + dt * model_output
# upon completion increase step index by one
self._step_index += 1
+ if per_token_timesteps is None:
+ # Cast sample back to model compatible dtype
+ prev_sample = prev_sample.to(model_output.dtype)
if not return_dict:
return (prev_sample,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+ """Constructs an exponential noise schedule."""
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+ return sigmas
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+ def _convert_to_beta(
+ self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+ ) -> torch.Tensor:
+ """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ sigmas = np.array(
+ [
+ sigma_min + (ppf * (sigma_max - sigma_min))
+ for ppf in [
+ scipy.stats.beta.ppf(timestep, alpha, beta)
+ for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+ ]
+ ]
+ )
+ return sigmas
+
+ def _time_shift_exponential(self, mu, sigma, t):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+ def _time_shift_linear(self, mu, sigma, t):
+ return mu / (mu + (1 / t - 1) ** sigma)
+
def __len__(self):
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
index cc7f6b8e9c57..2addc5f3eeec 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
@@ -228,13 +228,14 @@ def step(
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
- Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or
- tuple.
+ Whether or not to return a
+ [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] tuple.
Returns:
- [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
- If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is
- returned, otherwise a tuple is returned where the first element is the sample tensor.
+ [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
+ If return_dict is `True`,
+ [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] is returned,
+ otherwise a tuple is returned where the first element is the sample tensor.
"""
if (
@@ -245,7 +246,7 @@ def step(
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
- " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " `FlowMatchHeunDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py
index 63f38e86ab45..cb6cb9e79565 100644
--- a/src/diffusers/schedulers/scheduling_heun_discrete.py
+++ b/src/diffusers/schedulers/scheduling_heun_discrete.py
@@ -329,10 +329,10 @@ def set_timesteps(
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -342,7 +342,7 @@ def set_timesteps(
timesteps = torch.from_numpy(timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
- self.timesteps = timesteps.to(device=device)
+ self.timesteps = timesteps.to(device=device, dtype=torch.float32)
# empty dt and derivative
self.prev_derivative = None
@@ -421,7 +421,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -445,7 +445,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
index f76eb7c371b6..4b388b4d75b3 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -289,10 +289,10 @@ def set_timesteps(
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
@@ -409,7 +409,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -433,7 +433,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
index bf3b9f1437d2..a2e564e70a0e 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -288,10 +288,10 @@ def set_timesteps(
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
@@ -422,7 +422,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -446,7 +446,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index f1aa09ab1723..686b686f6870 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -643,16 +643,12 @@ def __len__(self):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
- if self.custom_timesteps:
+ if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
- num_inference_steps = (
- self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
- )
- prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+ prev_t = timestep - 1
return prev_t
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 0a0900455488..bcf9d9b59e11 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -302,16 +302,16 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
sigmas = self._convert_to_karras(in_sigmas=sigmas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
- self.timesteps = torch.from_numpy(timesteps).to(device=device)
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@@ -399,7 +399,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -423,7 +423,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py
index 97665bb5277b..a14797b42f7a 100644
--- a/src/diffusers/schedulers/scheduling_repaint.py
+++ b/src/diffusers/schedulers/scheduling_repaint.py
@@ -319,6 +319,10 @@ def step(
prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
# 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
+ # The computation reported in Algorithm 1 Line 5 is incorrect. Line 5 refers to formula (8a) of the same paper,
+ # which tells to sample from a Gaussian distribution with mean "(alpha_prod_t_prev**0.5) * original_image"
+ # and variance "(1 - alpha_prod_t_prev)". This means that the standard Gaussian distribution "noise" should be
+ # scaled by the square root of the variance (as it is done here), however Algorithm 1 Line 5 tells to scale by the variance.
prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise
# 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py
index 7188be5caaea..d45c93880bc5 100644
--- a/src/diffusers/schedulers/scheduling_sasolver.py
+++ b/src/diffusers/schedulers/scheduling_sasolver.py
@@ -167,6 +167,8 @@ def __init__(
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
+ use_flow_sigmas: Optional[bool] = False,
+ flow_shift: Optional[float] = 1.0,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
@@ -295,18 +297,28 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ log_sigmas = np.log(sigmas)
if self.config.use_karras_sigmas:
- log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = np.flip(sigmas).copy()
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ sigmas = np.flip(sigmas).copy()
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+ elif self.config.use_flow_sigmas:
+ alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+ sigmas = 1.0 - alphas
+ sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+ timesteps = (sigmas * self.config.num_train_timesteps).copy()
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -387,8 +399,12 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
- alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
- sigma_t = sigma * alpha_t
+ if self.config.use_flow_sigmas:
+ alpha_t = 1 - sigma
+ sigma_t = sigma
+ else:
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+ sigma_t = sigma * alpha_t
return alpha_t, sigma_t
@@ -437,7 +453,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -461,7 +477,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
@@ -527,10 +543,13 @@ def convert_model_output(
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
x0_pred = alpha_t * sample - sigma_t * model_output
+ elif self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
- " `v_prediction` for the SASolverScheduler."
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ "`v_prediction`, or `flow_prediction` for the SASolverScheduler."
)
if self.config.thresholding:
diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py
new file mode 100644
index 000000000000..23f47f42a302
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_scm.py
@@ -0,0 +1,265 @@
+# # Copyright 2024 Sana-Sprint Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..schedulers.scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, logging
+from ..utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->SCM
+class SCMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.Tensor
+ pred_original_sample: Optional[torch.Tensor] = None
+
+
+class SCMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `SCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+ non-Markovian guidance. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass
+ documentation for the generic methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ prediction_type (`str`, defaults to `trigflow`):
+ Prediction type of the scheduler function. Currently only supports "trigflow".
+ sigma_data (`float`, defaults to 0.5):
+ The standard deviation of the noise added during multi-step inference.
+ """
+
+ # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ prediction_type: str = "trigflow",
+ sigma_data: float = 0.5,
+ ):
+ """
+ Initialize the SCM scheduler.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ prediction_type (`str`, defaults to `trigflow`):
+ Prediction type of the scheduler function. Currently only supports "trigflow".
+ sigma_data (`float`, defaults to 0.5):
+ The standard deviation of the noise added during multi-step inference.
+ """
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ self._step_index = None
+ self._begin_index = None
+
+ @property
+ def step_index(self):
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ timesteps: torch.Tensor = None,
+ device: Union[str, torch.device] = None,
+ max_timesteps: float = 1.57080,
+ intermediate_timesteps: float = 1.3,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ timesteps (`torch.Tensor`, *optional*):
+ Custom timesteps to use for the denoising process.
+ max_timesteps (`float`, defaults to 1.57080):
+ The maximum timestep value used in the SCM scheduler.
+ intermediate_timesteps (`float`, *optional*, defaults to 1.3):
+ The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
+ """
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ if timesteps is not None and len(timesteps) != num_inference_steps + 1:
+ raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
+
+ if timesteps is not None and max_timesteps is not None:
+ raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
+
+ if timesteps is None and max_timesteps is None:
+ raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
+
+ if intermediate_timesteps is not None and num_inference_steps != 2:
+ raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
+
+ self.num_inference_steps = num_inference_steps
+
+ if timesteps is not None:
+ if isinstance(timesteps, list):
+ self.timesteps = torch.tensor(timesteps, device=device).float()
+ elif isinstance(timesteps, torch.Tensor):
+ self.timesteps = timesteps.to(device).float()
+ else:
+ raise ValueError(f"Unsupported timesteps type: {type(timesteps)}")
+ elif intermediate_timesteps is not None:
+ self.timesteps = torch.tensor([max_timesteps, intermediate_timesteps, 0], device=device).float()
+ else:
+ # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
+ self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
+ print(f"Set timesteps: {self.timesteps}")
+
+ self._step_index = None
+ self._begin_index = None
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+ def _init_step_index(self, timestep):
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: float,
+ sample: torch.FloatTensor,
+ generator: torch.Generator = None,
+ return_dict: bool = True,
+ ) -> Union[SCMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_scm.SCMSchedulerOutput`] or `tuple`.
+ Returns:
+ [`~schedulers.scheduling_utils.SCMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_scm.SCMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # 2. compute alphas, betas
+ t = self.timesteps[self.step_index + 1]
+ s = self.timesteps[self.step_index]
+
+ # 4. Different Parameterization:
+ parameterization = self.config.prediction_type
+
+ if parameterization == "trigflow":
+ pred_x0 = torch.cos(s) * sample - torch.sin(s) * model_output
+ else:
+ raise ValueError(f"Unsupported parameterization: {parameterization}")
+
+ # 5. Sample z ~ N(0, I), For MultiStep Inference
+ # Noise is not used for one-step sampling.
+ if len(self.timesteps) > 1:
+ noise = (
+ randn_tensor(model_output.shape, device=model_output.device, generator=generator)
+ * self.config.sigma_data
+ )
+ prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise
+ else:
+ prev_sample = pred_x0
+
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample, pred_x0)
+
+ return SCMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py
index 580224404c54..5d60383142a4 100644
--- a/src/diffusers/schedulers/scheduling_tcd.py
+++ b/src/diffusers/schedulers/scheduling_tcd.py
@@ -680,16 +680,12 @@ def __len__(self):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
- if self.custom_timesteps:
+ if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
- num_inference_steps = (
- self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
- )
- prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+ prev_t = timestep - 1
return prev_t
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 195e9c8477a2..01500426305c 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -206,6 +206,8 @@ def __init__(
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
+ use_flow_sigmas: Optional[bool] = False,
+ flow_shift: Optional[float] = 1.0,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
@@ -347,11 +349,47 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
- sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ log_sigmas = np.log(sigmas)
+ sigmas = np.flip(sigmas).copy()
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = sigmas[-1]
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_beta_sigmas:
- sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ log_sigmas = np.log(sigmas)
+ sigmas = np.flip(sigmas).copy()
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = sigmas[-1]
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+ elif self.config.use_flow_sigmas:
+ alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
+ sigmas = 1.0 - alphas
+ sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
+ timesteps = (sigmas * self.config.num_train_timesteps).copy()
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = sigmas[-1]
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
@@ -442,8 +480,12 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
- alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
- sigma_t = sigma * alpha_t
+ if self.config.use_flow_sigmas:
+ alpha_t = 1 - sigma
+ sigma_t = sigma
+ else:
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
+ sigma_t = sigma * alpha_t
return alpha_t, sigma_t
@@ -492,7 +534,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
@@ -516,7 +558,7 @@ def _convert_to_beta(
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
- sigmas = torch.Tensor(
+ sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
@@ -572,10 +614,13 @@ def convert_model_output(
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
x0_pred = alpha_t * sample - sigma_t * model_output
+ elif self.config.prediction_type == "flow_prediction":
+ sigma_t = self.sigmas[self.step_index]
+ x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
- " `v_prediction` for the UniPCMultistepScheduler."
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
+ "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
)
if self.config.thresholding:
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index f20224b19009..83f31b72c10b 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -19,6 +19,7 @@
import torch
from huggingface_hub.utils import validate_hf_hub_args
+from typing_extensions import Self
from ..utils import BaseOutput, PushToHubMixin
@@ -99,7 +100,7 @@ def from_pretrained(
subfolder: Optional[str] = None,
return_unused_kwargs=False,
**kwargs,
- ):
+ ) -> Self:
r"""
Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository.
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 9c898ad141ee..c570bac733db 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -43,6 +43,9 @@ def set_seed(seed: int):
Args:
seed (`int`): The seed to set.
+
+ Returns:
+ `None`
"""
random.seed(seed)
np.random.seed(seed)
@@ -58,6 +61,17 @@ def compute_snr(noise_scheduler, timesteps):
"""
Computes SNR as per
https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+ for the given timesteps using the provided noise scheduler.
+
+ Args:
+ noise_scheduler (`NoiseScheduler`):
+ An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
+ the SNR values.
+ timesteps (`torch.Tensor`):
+ A tensor of timesteps for which the SNR is computed.
+
+ Returns:
+ `torch.Tensor`: A tensor containing the computed SNR values for each timestep.
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
@@ -234,7 +248,13 @@ def _set_state_dict_into_text_encoder(
def compute_density_for_timestep_sampling(
- weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+ weighting_scheme: str,
+ batch_size: int,
+ logit_mean: float = None,
+ logit_std: float = None,
+ mode_scale: float = None,
+ device: Union[torch.device, str] = "cpu",
+ generator: Optional[torch.Generator] = None,
):
"""
Compute the density for sampling the timesteps when doing SD3 training.
@@ -244,14 +264,13 @@ def compute_density_for_timestep_sampling(
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
- # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
- u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
- u = torch.rand(size=(batch_size,), device="cpu")
+ u = torch.rand(size=(batch_size,), device=device, generator=generator)
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
- u = torch.rand(size=(batch_size,), device="cpu")
+ u = torch.rand(size=(batch_size,), device=device, generator=generator)
return u
@@ -284,7 +303,9 @@ def free_memory():
elif torch.backends.mps.is_available():
torch.mps.empty_cache()
elif is_torch_npu_available():
- torch_npu.empty_cache()
+ torch_npu.npu.empty_cache()
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
+ torch.xpu.empty_cache()
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
@@ -379,7 +400,7 @@ def __init__(
@classmethod
def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
- _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
+ _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
model = model_cls.from_pretrained(path)
ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index c8f64adf3e8a..50a470772772 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME,
+ GGUF_FILE_EXTENSION,
HF_MODULES_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
MIN_PEFT_VERSION,
@@ -66,7 +67,10 @@
is_bs4_available,
is_flax_available,
is_ftfy_available,
+ is_gguf_available,
+ is_gguf_version,
is_google_colab,
+ is_hf_hub_version,
is_inflect_available,
is_invisible_watermark_available,
is_k_diffusion_available,
@@ -75,6 +79,8 @@
is_matplotlib_available,
is_note_seq_available,
is_onnx_available,
+ is_optimum_quanto_available,
+ is_optimum_quanto_version,
is_peft_available,
is_peft_version,
is_safetensors_available,
@@ -86,6 +92,9 @@
is_torch_npu_available,
is_torch_version,
is_torch_xla_available,
+ is_torch_xla_version,
+ is_torchao_available,
+ is_torchao_version,
is_torchsde_available,
is_torchvision_available,
is_transformers_available,
@@ -95,7 +104,7 @@
is_xformers_available,
requires_backends,
)
-from .loading_utils import get_module_from_name, load_image, load_video
+from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
@@ -110,6 +119,7 @@
unscale_lora_layers,
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
+from .remote_utils import remote_decode
from .state_dict_utils import (
convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers,
@@ -117,6 +127,7 @@
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
)
+from .typing_utils import _get_detailed_type, _is_valid_type
logger = get_logger(__name__)
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 553ac5d1bb27..fa12318f4714 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
SAFETENSORS_FILE_EXTENSION = "safetensors"
+GGUF_FILE_EXTENSION = "gguf"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
@@ -55,3 +56,14 @@
if USE_PEFT_BACKEND and _CHECK_PEFT:
dep_version_check("peft")
+
+
+DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
+DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
+
+
+ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
+ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
+ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
diff --git a/src/diffusers/utils/dummy_bitsandbytes_objects.py b/src/diffusers/utils/dummy_bitsandbytes_objects.py
new file mode 100644
index 000000000000..2dc589428de9
--- /dev/null
+++ b/src/diffusers/utils/dummy_bitsandbytes_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class BitsAndBytesConfig(metaclass=DummyObject):
+ _backends = ["bitsandbytes"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["bitsandbytes"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["bitsandbytes"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["bitsandbytes"])
diff --git a/src/diffusers/utils/dummy_gguf_objects.py b/src/diffusers/utils/dummy_gguf_objects.py
new file mode 100644
index 000000000000..4a6d9a060a13
--- /dev/null
+++ b/src/diffusers/utils/dummy_gguf_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class GGUFQuantizationConfig(metaclass=DummyObject):
+ _backends = ["gguf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["gguf"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["gguf"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["gguf"])
diff --git a/src/diffusers/utils/dummy_optimum_quanto_objects.py b/src/diffusers/utils/dummy_optimum_quanto_objects.py
new file mode 100644
index 000000000000..44f8eaffc246
--- /dev/null
+++ b/src/diffusers/utils/dummy_optimum_quanto_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class QuantoConfig(metaclass=DummyObject):
+ _backends = ["optimum_quanto"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["optimum_quanto"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["optimum_quanto"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["optimum_quanto"])
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 10d0399a6761..6edbd737e32c 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -2,6 +2,74 @@
from ..utils import DummyObject, requires_backends
+class FasterCacheConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class HookRegistry(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+def apply_faster_cache(*args, **kwargs):
+ requires_backends(apply_faster_cache, ["torch"])
+
+
+def apply_pyramid_attention_broadcast(*args, **kwargs):
+ requires_backends(apply_pyramid_attention_broadcast, ["torch"])
+
+
+class AllegroTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AsymmetricAutoencoderKL(metaclass=DummyObject):
_backends = ["torch"]
@@ -32,6 +100,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AutoencoderDC(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoencoderKL(metaclass=DummyObject):
_backends = ["torch"]
@@ -47,6 +130,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AutoencoderKLAllegro(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoencoderKLCogVideoX(metaclass=DummyObject):
_backends = ["torch"]
@@ -62,6 +160,66 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLLTXVideo(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLMagvit(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLMochi(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
_backends = ["torch"]
@@ -77,6 +235,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class AutoencoderKLWan(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoencoderOobleck(metaclass=DummyObject):
_backends = ["torch"]
@@ -107,6 +280,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class CacheMixin(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class CogVideoXTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -137,6 +325,36 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class CogView4Transformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ConsisIDTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ConsistencyDecoderVAE(metaclass=DummyObject):
_backends = ["torch"]
@@ -167,6 +385,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class ControlNetUnionModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ControlNetXSAdapter(metaclass=DummyObject):
_backends = ["torch"]
@@ -197,6 +430,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class EasyAnimateTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class FluxControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -287,6 +535,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class HunyuanVideoTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class I2VGenXLUNet(metaclass=DummyObject):
_backends = ["torch"]
@@ -332,6 +595,36 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class LTXVideoTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class Lumina2Transformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class LuminaNextDiT2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -347,6 +640,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class MochiTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ModelMixin(metaclass=DummyObject):
_backends = ["torch"]
@@ -392,6 +700,36 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class MultiControlNetModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class OmniGenTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -422,6 +760,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class SanaTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class SD3ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -677,6 +1030,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class WanTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch"])
@@ -1485,6 +1853,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class SCMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 9046a4f73533..b28fba948149 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends
+class AllegroPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -17,7 +32,412 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AltDiffusionPipeline(metaclass=DummyObject):
+class AltDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AmusedImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AmusedInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AmusedPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffPAGPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffSDXLPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffSparseControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffVideoToVideoControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AudioLDM2Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AudioLDM2ProjectionModel(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AudioLDM2UNet2DConditionModel(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AudioLDMPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AuraFlowPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CLIPImageProjection(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogVideoXFunControlPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogVideoXImageToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogVideoXPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogVideoXVideoToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogView3PlusPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogView4ControlPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogView4Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class ConsisIDPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CycleDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class EasyAnimateControlPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class EasyAnimateInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -32,7 +452,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AmusedImg2ImgPipeline(metaclass=DummyObject):
+class EasyAnimatePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -47,7 +467,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AmusedInpaintPipeline(metaclass=DummyObject):
+class FluxControlImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -62,7 +482,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AmusedPipeline(metaclass=DummyObject):
+class FluxControlInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -77,7 +497,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffControlNetPipeline(metaclass=DummyObject):
+class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -92,7 +512,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffPAGPipeline(metaclass=DummyObject):
+class FluxControlNetInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -107,7 +527,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffPipeline(metaclass=DummyObject):
+class FluxControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -122,7 +542,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffSDXLPipeline(metaclass=DummyObject):
+class FluxControlPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -137,7 +557,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffSparseControlNetPipeline(metaclass=DummyObject):
+class FluxFillPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -152,7 +572,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffVideoToVideoControlNetPipeline(metaclass=DummyObject):
+class FluxImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -167,7 +587,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
+class FluxInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -182,7 +602,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AudioLDM2Pipeline(metaclass=DummyObject):
+class FluxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -197,7 +617,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AudioLDM2ProjectionModel(metaclass=DummyObject):
+class FluxPriorReduxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -212,7 +632,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AudioLDM2UNet2DConditionModel(metaclass=DummyObject):
+class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -227,7 +647,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AudioLDMPipeline(metaclass=DummyObject):
+class HunyuanDiTPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -242,7 +662,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AuraFlowPipeline(metaclass=DummyObject):
+class HunyuanDiTPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -257,7 +677,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CLIPImageProjection(metaclass=DummyObject):
+class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -272,7 +692,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogVideoXFunControlPipeline(metaclass=DummyObject):
+class HunyuanVideoImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -287,7 +707,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogVideoXImageToVideoPipeline(metaclass=DummyObject):
+class HunyuanVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -302,7 +722,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogVideoXPipeline(metaclass=DummyObject):
+class I2VGenXLPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -317,7 +737,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogVideoXVideoToVideoPipeline(metaclass=DummyObject):
+class IFImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -332,7 +752,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogView3PlusPipeline(metaclass=DummyObject):
+class IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -347,7 +767,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CycleDiffusionPipeline(metaclass=DummyObject):
+class IFInpaintingPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -362,7 +782,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
+class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -377,7 +797,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlNetInpaintPipeline(metaclass=DummyObject):
+class IFPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -392,7 +812,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlNetPipeline(metaclass=DummyObject):
+class IFSuperResolutionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -407,7 +827,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxImg2ImgPipeline(metaclass=DummyObject):
+class ImageTextPipelineOutput(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -422,7 +842,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxInpaintPipeline(metaclass=DummyObject):
+class Kandinsky3Img2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -437,7 +857,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxPipeline(metaclass=DummyObject):
+class Kandinsky3Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -452,7 +872,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
+class KandinskyCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -467,7 +887,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class HunyuanDiTPAGPipeline(metaclass=DummyObject):
+class KandinskyImg2ImgCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -482,7 +902,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class HunyuanDiTPipeline(metaclass=DummyObject):
+class KandinskyImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -497,7 +917,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class I2VGenXLPipeline(metaclass=DummyObject):
+class KandinskyInpaintCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -512,7 +932,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFImg2ImgPipeline(metaclass=DummyObject):
+class KandinskyInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -527,7 +947,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject):
+class KandinskyPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -542,7 +962,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFInpaintingPipeline(metaclass=DummyObject):
+class KandinskyPriorPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -557,7 +977,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject):
+class KandinskyV22CombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -572,7 +992,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFPipeline(metaclass=DummyObject):
+class KandinskyV22ControlnetImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -587,7 +1007,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFSuperResolutionPipeline(metaclass=DummyObject):
+class KandinskyV22ControlnetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -602,7 +1022,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class ImageTextPipelineOutput(metaclass=DummyObject):
+class KandinskyV22Img2ImgCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -617,7 +1037,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class Kandinsky3Img2ImgPipeline(metaclass=DummyObject):
+class KandinskyV22Img2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -632,7 +1052,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class Kandinsky3Pipeline(metaclass=DummyObject):
+class KandinskyV22InpaintCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -647,7 +1067,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyCombinedPipeline(metaclass=DummyObject):
+class KandinskyV22InpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -662,7 +1082,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyImg2ImgCombinedPipeline(metaclass=DummyObject):
+class KandinskyV22Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -677,7 +1097,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyImg2ImgPipeline(metaclass=DummyObject):
+class KandinskyV22PriorEmb2EmbPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -692,7 +1112,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyInpaintCombinedPipeline(metaclass=DummyObject):
+class KandinskyV22PriorPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -707,7 +1127,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyInpaintPipeline(metaclass=DummyObject):
+class LatentConsistencyModelImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -722,7 +1142,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyPipeline(metaclass=DummyObject):
+class LatentConsistencyModelPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -737,7 +1157,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyPriorPipeline(metaclass=DummyObject):
+class LattePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -752,7 +1172,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22CombinedPipeline(metaclass=DummyObject):
+class LDMTextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -767,7 +1187,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22ControlnetImg2ImgPipeline(metaclass=DummyObject):
+class LEditsPPPipelineStableDiffusion(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -782,7 +1202,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22ControlnetPipeline(metaclass=DummyObject):
+class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -797,7 +1217,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22Img2ImgCombinedPipeline(metaclass=DummyObject):
+class LTXConditionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -812,7 +1232,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22Img2ImgPipeline(metaclass=DummyObject):
+class LTXImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -827,7 +1247,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22InpaintCombinedPipeline(metaclass=DummyObject):
+class LTXPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -842,7 +1262,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22InpaintPipeline(metaclass=DummyObject):
+class Lumina2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -857,7 +1277,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22Pipeline(metaclass=DummyObject):
+class Lumina2Text2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -872,7 +1292,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22PriorEmb2EmbPipeline(metaclass=DummyObject):
+class LuminaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -887,7 +1307,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22PriorPipeline(metaclass=DummyObject):
+class LuminaText2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -902,7 +1322,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LatentConsistencyModelImg2ImgPipeline(metaclass=DummyObject):
+class MarigoldDepthPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -917,7 +1337,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LatentConsistencyModelPipeline(metaclass=DummyObject):
+class MarigoldIntrinsicsPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -932,7 +1352,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LattePipeline(metaclass=DummyObject):
+class MarigoldNormalsPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -947,7 +1367,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LDMTextToImagePipeline(metaclass=DummyObject):
+class MochiPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -962,7 +1382,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LEditsPPPipelineStableDiffusion(metaclass=DummyObject):
+class MusicLDMPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -977,7 +1397,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
+class OmniGenPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -992,7 +1412,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LuminaText2ImgPipeline(metaclass=DummyObject):
+class PaintByExamplePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1007,7 +1427,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class MarigoldDepthPipeline(metaclass=DummyObject):
+class PIAPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1022,7 +1442,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class MarigoldNormalsPipeline(metaclass=DummyObject):
+class PixArtAlphaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1037,7 +1457,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class MusicLDMPipeline(metaclass=DummyObject):
+class PixArtSigmaPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1052,7 +1472,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class PaintByExamplePipeline(metaclass=DummyObject):
+class PixArtSigmaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1067,7 +1487,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class PIAPipeline(metaclass=DummyObject):
+class ReduxImageEncoder(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1082,7 +1502,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class PixArtAlphaPipeline(metaclass=DummyObject):
+class SanaPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1097,7 +1517,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class PixArtSigmaPAGPipeline(metaclass=DummyObject):
+class SanaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1112,7 +1532,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class PixArtSigmaPipeline(metaclass=DummyObject):
+class SanaSprintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1247,6 +1667,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class StableDiffusion3ControlNetInpaintingPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusion3ControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1292,6 +1727,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class StableDiffusion3PAGImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusion3PAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1637,6 +2087,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class StableDiffusionPAGInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusionPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1847,6 +2312,51 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class StableDiffusionXLControlNetUnionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionXLControlNetUnionInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionXLControlNetUnionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -2222,6 +2732,51 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class WanImageToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class WanPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class WanVideoToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class WuerstchenCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/dummy_torchao_objects.py b/src/diffusers/utils/dummy_torchao_objects.py
new file mode 100644
index 000000000000..16f0f6a55f64
--- /dev/null
+++ b/src/diffusers/utils/dummy_torchao_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class TorchAoConfig(metaclass=DummyObject):
+ _backends = ["torchao"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torchao"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torchao"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torchao"])
diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py
index f0cf953924ad..5d0752af8983 100644
--- a/src/diffusers/utils/dynamic_modules_utils.py
+++ b/src/diffusers/utils/dynamic_modules_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -325,7 +325,7 @@ def get_cached_module_file(
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
# that hash, to only copy when there is a modification but it seems overkill for now).
# The only reason we do the copy is to avoid putting too many folders in sys.path.
- shutil.copy(resolved_module_file, submodule_path / module_file)
+ shutil.copyfile(resolved_module_file, submodule_path / module_file)
for module_needed in modules_needed:
if len(module_needed.split(".")) == 2:
module_needed = "/".join(module_needed.split("."))
@@ -333,7 +333,7 @@ def get_cached_module_file(
if not os.path.exists(submodule_path / module_folder):
os.makedirs(submodule_path / module_folder)
module_needed = f"{module_needed}.py"
- shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+ shutil.copyfile(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
else:
# Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
@@ -350,7 +350,7 @@ def get_cached_module_file(
module_folder = module_file.split("/")[0]
if not os.path.exists(submodule_path / module_folder):
os.makedirs(submodule_path / module_folder)
- shutil.copy(resolved_module_file, submodule_path / module_file)
+ shutil.copyfile(resolved_module_file, submodule_path / module_file)
# Make sure we also have every file with relative
for module_needed in modules_needed:
diff --git a/src/diffusers/utils/export_utils.py b/src/diffusers/utils/export_utils.py
index 00805433ceba..30d2c8bebd8e 100644
--- a/src/diffusers/utils/export_utils.py
+++ b/src/diffusers/utils/export_utils.py
@@ -3,7 +3,7 @@
import struct
import tempfile
from contextlib import contextmanager
-from typing import List, Union
+from typing import List, Optional, Union
import numpy as np
import PIL.Image
@@ -139,8 +139,31 @@ def _legacy_export_to_video(
def export_to_video(
- video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
+ video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
+ output_video_path: str = None,
+ fps: int = 10,
+ quality: float = 5.0,
+ bitrate: Optional[int] = None,
+ macro_block_size: Optional[int] = 16,
) -> str:
+ """
+ quality:
+ Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to
+ prevent variable bitrate flags to FFMPEG so you can manually specify them using output_params instead.
+ Specifying a fixed bitrate using `bitrate` disables this parameter.
+
+ bitrate:
+ Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead.
+ Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter
+ rather than specifiying a fixed bitrate with this parameter.
+
+ macro_block_size:
+ Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number
+ imageio will tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs
+ are compatible with a macroblock size of 16 (default), some can go smaller (4, 8). To disable this automatic
+ feature set it to None or 1, however be warned many players can't decode videos that are odd in size and some
+ codecs will produce poor results or fail. See https://en.wikipedia.org/wiki/Macroblock.
+ """
# TODO: Dhruv. Remove by Diffusers release 0.33.0
# Added to prevent breaking existing code
if not is_imageio_available():
@@ -177,7 +200,9 @@ def export_to_video(
elif isinstance(video_frames[0], PIL.Image.Image):
video_frames = [np.array(frame) for frame in video_frames]
- with imageio.get_writer(output_video_path, fps=fps) as writer:
+ with imageio.get_writer(
+ output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size
+ ) as writer:
for frame in video_frames:
writer.append_data(frame)
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index 448e92509732..f80f96a3425d 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,13 +19,13 @@
import re
import sys
import tempfile
-import traceback
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Union
from uuid import uuid4
from huggingface_hub import (
+ DDUFEntry,
ModelCard,
ModelCardData,
create_repo,
@@ -34,7 +34,7 @@
snapshot_download,
upload_folder,
)
-from huggingface_hub.constants import HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
+from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import (
EntryNotFoundError,
@@ -196,78 +196,6 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
-# Old default cache path, potentially to be migrated.
-# This logic was more or less taken from `transformers`, with the following differences:
-# - Diffusers doesn't use custom environment variables to specify the cache path.
-# - There is no need to migrate the cache format, just move the files to the new location.
-hf_cache_home = os.path.expanduser(
- os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
-)
-old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")
-
-
-def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
- if new_cache_dir is None:
- new_cache_dir = HF_HUB_CACHE
- if old_cache_dir is None:
- old_cache_dir = old_diffusers_cache
-
- old_cache_dir = Path(old_cache_dir).expanduser()
- new_cache_dir = Path(new_cache_dir).expanduser()
- for old_blob_path in old_cache_dir.glob("**/blobs/*"):
- if old_blob_path.is_file() and not old_blob_path.is_symlink():
- new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
- new_blob_path.parent.mkdir(parents=True, exist_ok=True)
- os.replace(old_blob_path, new_blob_path)
- try:
- os.symlink(new_blob_path, old_blob_path)
- except OSError:
- logger.warning(
- "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded."
- )
- # At this point, old_cache_dir contains symlinks to the new cache (it can still be used).
-
-
-cache_version_file = os.path.join(HF_HUB_CACHE, "version_diffusers_cache.txt")
-if not os.path.isfile(cache_version_file):
- cache_version = 0
-else:
- with open(cache_version_file) as f:
- try:
- cache_version = int(f.read())
- except ValueError:
- cache_version = 0
-
-if cache_version < 1:
- old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0
- if old_cache_is_not_empty:
- logger.warning(
- "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your "
- "existing cached models. This is a one-time operation, you can interrupt it or run it "
- "later by calling `diffusers.utils.hub_utils.move_cache()`."
- )
- try:
- move_cache()
- except Exception as e:
- trace = "\n".join(traceback.format_tb(e.__traceback__))
- logger.error(
- f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
- "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole "
- "message and we will do our best to help."
- )
-
-if cache_version < 1:
- try:
- os.makedirs(HF_HUB_CACHE, exist_ok=True)
- with open(cache_version_file, "w") as f:
- f.write("1")
- except Exception:
- logger.warning(
- f"There was a problem when trying to write in your cache folder ({HF_HUB_CACHE}). Please, ensure "
- "the directory exists and can be written to."
- )
-
-
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
splits = weights_name.split(".")
@@ -291,9 +219,26 @@ def _get_model_file(
user_agent: Optional[Union[Dict, str]] = None,
revision: Optional[str] = None,
commit_hash: Optional[str] = None,
+ dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
- if os.path.isfile(pretrained_model_name_or_path):
+
+ if dduf_entries:
+ if subfolder is not None:
+ raise ValueError(
+ "DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
+ "Please check the DDUF structure"
+ )
+ model_file = (
+ weights_name
+ if pretrained_model_name_or_path == ""
+ else "/".join([pretrained_model_name_or_path, weights_name])
+ )
+ if model_file in dduf_entries:
+ return model_file
+ else:
+ raise EnvironmentError(f"Error no file named {weights_name} found in archive {dduf_entries.keys()}.")
+ elif os.path.isfile(pretrained_model_name_or_path):
return pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
@@ -393,22 +338,6 @@ def _get_model_file(
) from e
-# Adapted from
-# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976
-# Differences are in parallelization of shard downloads and checking if shards are present.
-
-
-def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames):
- shards_path = os.path.join(local_dir, subfolder)
- shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
- for shard_file in shard_filenames:
- if not os.path.exists(shard_file):
- raise ValueError(
- f"{shards_path} does not appear to have a file named {shard_file} which is "
- "required according to the checkpoint index."
- )
-
-
def _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_filename,
@@ -419,6 +348,7 @@ def _get_checkpoint_shard_files(
user_agent=None,
revision=None,
subfolder="",
+ dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
"""
For a given model:
@@ -430,11 +360,18 @@ def _get_checkpoint_shard_files(
For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
"""
- if not os.path.isfile(index_filename):
- raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
+ if dduf_entries:
+ if index_filename not in dduf_entries:
+ raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
+ else:
+ if not os.path.isfile(index_filename):
+ raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
- with open(index_filename, "r") as f:
- index = json.loads(f.read())
+ if dduf_entries:
+ index = json.loads(dduf_entries[index_filename].read_text())
+ else:
+ with open(index_filename, "r") as f:
+ index = json.loads(f.read())
original_shard_filenames = sorted(set(index["weight_map"].values()))
sharded_metadata = index["metadata"]
@@ -443,11 +380,22 @@ def _get_checkpoint_shard_files(
shards_path = os.path.join(pretrained_model_name_or_path, subfolder)
# First, let's deal with local folder.
- if os.path.isdir(pretrained_model_name_or_path):
- _check_if_shards_exist_locally(
- pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
- )
- return shards_path, sharded_metadata
+ if os.path.isdir(pretrained_model_name_or_path) or dduf_entries:
+ shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
+ for shard_file in shard_filenames:
+ if dduf_entries:
+ if shard_file not in dduf_entries:
+ raise FileNotFoundError(
+ f"{shards_path} does not appear to have a file named {shard_file} which is "
+ "required according to the checkpoint index."
+ )
+ else:
+ if not os.path.exists(shard_file):
+ raise FileNotFoundError(
+ f"{shards_path} does not appear to have a file named {shard_file} which is "
+ "required according to the checkpoint index."
+ )
+ return shard_filenames, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
allow_patterns = original_shard_filenames
@@ -455,50 +403,43 @@ def _get_checkpoint_shard_files(
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
ignore_patterns = ["*.json", "*.md"]
- if not local_files_only:
- # `model_info` call must guarded with the above condition.
- model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
- for shard_file in original_shard_filenames:
- shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
- if not shard_file_present:
- raise EnvironmentError(
- f"{shards_path} does not appear to have a file named {shard_file} which is "
- "required according to the checkpoint index."
- )
-
- try:
- # Load from URL
- cached_folder = snapshot_download(
- pretrained_model_name_or_path,
- cache_dir=cache_dir,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- allow_patterns=allow_patterns,
- ignore_patterns=ignore_patterns,
- user_agent=user_agent,
- )
- if subfolder is not None:
- cached_folder = os.path.join(cached_folder, subfolder)
-
- # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
- # we don't have to catch them here. We have also dealt with EntryNotFoundError.
- except HTTPError as e:
+ # `model_info` call must guarded with the above condition.
+ model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+ for shard_file in original_shard_filenames:
+ shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+ if not shard_file_present:
raise EnvironmentError(
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
- " again after checking your internet connection."
- ) from e
+ f"{shards_path} does not appear to have a file named {shard_file} which is "
+ "required according to the checkpoint index."
+ )
- # If `local_files_only=True`, `cached_folder` may not contain all the shard files.
- elif local_files_only:
- _check_if_shards_exist_locally(
- local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
+ try:
+ # Load from URL
+ cached_folder = snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ allow_patterns=allow_patterns,
+ ignore_patterns=ignore_patterns,
+ user_agent=user_agent,
)
if subfolder is not None:
- cached_folder = os.path.join(cache_dir, subfolder)
+ cached_folder = os.path.join(cached_folder, subfolder)
+
+ # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
+ # we don't have to catch them here. We have also dealt with EntryNotFoundError.
+ except HTTPError as e:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
+ " again after checking your internet connection."
+ ) from e
+
+ cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
- return cached_folder, sharded_metadata
+ return cached_filenames, sharded_metadata
def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
@@ -564,7 +505,8 @@ def push_to_hub(
commit_message (`str`, *optional*):
Message to commit while pushing. Default to `"Upload {object}"`.
private (`bool`, *optional*):
- Whether or not the repository created should be private.
+ Whether to make the repo private. If `None` (default), the repo will be public unless the
+ organization's default is private. This value is ignored if the repo already exists.
token (`str`, *optional*):
The token to use as HTTP bearer authorization for remote files. The token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index f1323bf00ea4..f61116aaaf6c 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -25,7 +25,6 @@
from typing import Any, Union
from huggingface_hub.utils import is_jinja_available # noqa: F401
-from packaging import version
from packaging.version import Version, parse
from . import logging
@@ -52,36 +51,30 @@
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
-_torch_version = "N/A"
-if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
- _torch_available = importlib.util.find_spec("torch") is not None
- if _torch_available:
+_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
+
+
+def _is_package_available(pkg_name: str):
+ pkg_exists = importlib.util.find_spec(pkg_name) is not None
+ pkg_version = "N/A"
+
+ if pkg_exists:
try:
- _torch_version = importlib_metadata.version("torch")
- logger.info(f"PyTorch version {_torch_version} available.")
- except importlib_metadata.PackageNotFoundError:
- _torch_available = False
+ pkg_version = importlib_metadata.version(pkg_name)
+ logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
+ except (ImportError, importlib_metadata.PackageNotFoundError):
+ pkg_exists = False
+
+ return pkg_exists, pkg_version
+
+
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+ _torch_available, _torch_version = _is_package_available("torch")
+
else:
logger.info("Disabling PyTorch because USE_TORCH is set")
_torch_available = False
-_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
-if _torch_xla_available:
- try:
- _torch_xla_version = importlib_metadata.version("torch_xla")
- logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
- except ImportError:
- _torch_xla_available = False
-
-# check whether torch_npu is available
-_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
-if _torch_npu_available:
- try:
- _torch_npu_version = importlib_metadata.version("torch_npu")
- logger.info(f"torch_npu version {_torch_npu_version} available.")
- except ImportError:
- _torch_npu_available = False
-
_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
@@ -97,40 +90,12 @@
_flax_available = False
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
- _safetensors_available = importlib.util.find_spec("safetensors") is not None
- if _safetensors_available:
- try:
- _safetensors_version = importlib_metadata.version("safetensors")
- logger.info(f"Safetensors version {_safetensors_version} available.")
- except importlib_metadata.PackageNotFoundError:
- _safetensors_available = False
+ _safetensors_available, _safetensors_version = _is_package_available("safetensors")
+
else:
logger.info("Disabling Safetensors because USE_TF is set")
_safetensors_available = False
-_transformers_available = importlib.util.find_spec("transformers") is not None
-try:
- _transformers_version = importlib_metadata.version("transformers")
- logger.debug(f"Successfully imported transformers version {_transformers_version}")
-except importlib_metadata.PackageNotFoundError:
- _transformers_available = False
-
-
-_inflect_available = importlib.util.find_spec("inflect") is not None
-try:
- _inflect_version = importlib_metadata.version("inflect")
- logger.debug(f"Successfully imported inflect version {_inflect_version}")
-except importlib_metadata.PackageNotFoundError:
- _inflect_available = False
-
-
-_unidecode_available = importlib.util.find_spec("unidecode") is not None
-try:
- _unidecode_version = importlib_metadata.version("unidecode")
- logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
-except importlib_metadata.PackageNotFoundError:
- _unidecode_available = False
-
_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available:
@@ -142,7 +107,9 @@
"onnxruntime-openvino",
"ort_nightly_directml",
"onnxruntime-rocm",
+ "onnxruntime-migraphx",
"onnxruntime-training",
+ "onnxruntime-vitisai",
)
_onnxruntime_version = None
# For the metadata, we have to look for both onnxruntime and onnxruntime-gpu
@@ -178,85 +145,6 @@
except importlib_metadata.PackageNotFoundError:
_opencv_available = False
-_scipy_available = importlib.util.find_spec("scipy") is not None
-try:
- _scipy_version = importlib_metadata.version("scipy")
- logger.debug(f"Successfully imported scipy version {_scipy_version}")
-except importlib_metadata.PackageNotFoundError:
- _scipy_available = False
-
-_librosa_available = importlib.util.find_spec("librosa") is not None
-try:
- _librosa_version = importlib_metadata.version("librosa")
- logger.debug(f"Successfully imported librosa version {_librosa_version}")
-except importlib_metadata.PackageNotFoundError:
- _librosa_available = False
-
-_accelerate_available = importlib.util.find_spec("accelerate") is not None
-try:
- _accelerate_version = importlib_metadata.version("accelerate")
- logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
-except importlib_metadata.PackageNotFoundError:
- _accelerate_available = False
-
-_xformers_available = importlib.util.find_spec("xformers") is not None
-try:
- _xformers_version = importlib_metadata.version("xformers")
- if _torch_available:
- _torch_version = importlib_metadata.version("torch")
- if version.Version(_torch_version) < version.Version("1.12"):
- raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12")
-
- logger.debug(f"Successfully imported xformers version {_xformers_version}")
-except importlib_metadata.PackageNotFoundError:
- _xformers_available = False
-
-_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
-try:
- _k_diffusion_version = importlib_metadata.version("k_diffusion")
- logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
-except importlib_metadata.PackageNotFoundError:
- _k_diffusion_available = False
-
-_note_seq_available = importlib.util.find_spec("note_seq") is not None
-try:
- _note_seq_version = importlib_metadata.version("note_seq")
- logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
-except importlib_metadata.PackageNotFoundError:
- _note_seq_available = False
-
-_wandb_available = importlib.util.find_spec("wandb") is not None
-try:
- _wandb_version = importlib_metadata.version("wandb")
- logger.debug(f"Successfully imported wandb version {_wandb_version }")
-except importlib_metadata.PackageNotFoundError:
- _wandb_available = False
-
-
-_tensorboard_available = importlib.util.find_spec("tensorboard")
-try:
- _tensorboard_version = importlib_metadata.version("tensorboard")
- logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
-except importlib_metadata.PackageNotFoundError:
- _tensorboard_available = False
-
-
-_compel_available = importlib.util.find_spec("compel")
-try:
- _compel_version = importlib_metadata.version("compel")
- logger.debug(f"Successfully imported compel version {_compel_version}")
-except importlib_metadata.PackageNotFoundError:
- _compel_available = False
-
-
-_ftfy_available = importlib.util.find_spec("ftfy") is not None
-try:
- _ftfy_version = importlib_metadata.version("ftfy")
- logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
-except importlib_metadata.PackageNotFoundError:
- _ftfy_available = False
-
-
_bs4_available = importlib.util.find_spec("bs4") is not None
try:
# importlib metadata under different name
@@ -265,13 +153,6 @@
except importlib_metadata.PackageNotFoundError:
_bs4_available = False
-_torchsde_available = importlib.util.find_spec("torchsde") is not None
-try:
- _torchsde_version = importlib_metadata.version("torchsde")
- logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
-except importlib_metadata.PackageNotFoundError:
- _torchsde_available = False
-
_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
try:
_invisible_watermark_version = importlib_metadata.version("invisible-watermark")
@@ -279,65 +160,42 @@
except importlib_metadata.PackageNotFoundError:
_invisible_watermark_available = False
-
-_peft_available = importlib.util.find_spec("peft") is not None
-try:
- _peft_version = importlib_metadata.version("peft")
- logger.debug(f"Successfully imported peft version {_peft_version}")
-except importlib_metadata.PackageNotFoundError:
- _peft_available = False
-
-_torchvision_available = importlib.util.find_spec("torchvision") is not None
-try:
- _torchvision_version = importlib_metadata.version("torchvision")
- logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
-except importlib_metadata.PackageNotFoundError:
- _torchvision_available = False
-
-_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None
-try:
- _sentencepiece_version = importlib_metadata.version("sentencepiece")
- logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}")
-except importlib_metadata.PackageNotFoundError:
- _sentencepiece_available = False
-
-_matplotlib_available = importlib.util.find_spec("matplotlib") is not None
-try:
- _matplotlib_version = importlib_metadata.version("matplotlib")
- logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
-except importlib_metadata.PackageNotFoundError:
- _matplotlib_available = False
-
-_timm_available = importlib.util.find_spec("timm") is not None
-if _timm_available:
- try:
- _timm_version = importlib_metadata.version("timm")
- logger.info(f"Timm version {_timm_version} available.")
- except importlib_metadata.PackageNotFoundError:
- _timm_available = False
-
-
-def is_timm_available():
- return _timm_available
-
-
-_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
-try:
- _bitsandbytes_version = importlib_metadata.version("bitsandbytes")
- logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
-except importlib_metadata.PackageNotFoundError:
- _bitsandbytes_available = False
-
-_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
-
-_imageio_available = importlib.util.find_spec("imageio") is not None
-if _imageio_available:
+_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
+_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
+_transformers_available, _transformers_version = _is_package_available("transformers")
+_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
+_inflect_available, _inflect_version = _is_package_available("inflect")
+_unidecode_available, _unidecode_version = _is_package_available("unidecode")
+_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
+_note_seq_available, _note_seq_version = _is_package_available("note_seq")
+_wandb_available, _wandb_version = _is_package_available("wandb")
+_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
+_compel_available, _compel_version = _is_package_available("compel")
+_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
+_torchsde_available, _torchsde_version = _is_package_available("torchsde")
+_peft_available, _peft_version = _is_package_available("peft")
+_torchvision_available, _torchvision_version = _is_package_available("torchvision")
+_matplotlib_available, _matplotlib_version = _is_package_available("matplotlib")
+_timm_available, _timm_version = _is_package_available("timm")
+_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
+_imageio_available, _imageio_version = _is_package_available("imageio")
+_ftfy_available, _ftfy_version = _is_package_available("ftfy")
+_scipy_available, _scipy_version = _is_package_available("scipy")
+_librosa_available, _librosa_version = _is_package_available("librosa")
+_accelerate_available, _accelerate_version = _is_package_available("accelerate")
+_xformers_available, _xformers_version = _is_package_available("xformers")
+_gguf_available, _gguf_version = _is_package_available("gguf")
+_torchao_available, _torchao_version = _is_package_available("torchao")
+_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
+_torchao_available, _torchao_version = _is_package_available("torchao")
+
+_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
+if _optimum_quanto_available:
try:
- _imageio_version = importlib_metadata.version("imageio")
- logger.debug(f"Successfully imported imageio version {_imageio_version}")
-
+ _optimum_quanto_version = importlib_metadata.version("optimum_quanto")
+ logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}")
except importlib_metadata.PackageNotFoundError:
- _imageio_available = False
+ _optimum_quanto_available = False
def is_torch_available():
@@ -460,6 +318,22 @@ def is_imageio_available():
return _imageio_available
+def is_gguf_available():
+ return _gguf_available
+
+
+def is_torchao_available():
+ return _torchao_available
+
+
+def is_optimum_quanto_available():
+ return _optimum_quanto_available
+
+
+def is_timm_available():
+ return _timm_available
+
+
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -593,6 +467,21 @@ def is_imageio_available():
{0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg`
"""
+# docstyle-ignore
+GGUF_IMPORT_ERROR = """
+{0} requires the gguf library but it was not found in your environment. You can install it with pip: `pip install gguf`
+"""
+
+TORCHAO_IMPORT_ERROR = """
+{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install
+torchao`
+"""
+
+QUANTO_IMPORT_ERROR = """
+{0} requires the optimum-quanto library but it was not found in your environment. You can install it with pip: `pip
+install optimum-quanto`
+"""
+
BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -618,6 +507,9 @@ def is_imageio_available():
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
+ ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
+ ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
+ ("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
]
)
@@ -700,6 +592,21 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version)
+def is_torch_xla_version(operation: str, version: str):
+ """
+ Compares the current torch_xla version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A string version of torch_xla
+ """
+ if not is_torch_xla_available:
+ return False
+ return compare_versions(parse(_torch_xla_version), operation, version)
+
+
def is_transformers_version(operation: str, version: str):
"""
Compares the current Transformers version to a given reference with an operation.
@@ -715,6 +622,21 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)
+def is_hf_hub_version(operation: str, version: str):
+ """
+ Compares the current Hugging Face Hub version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _hf_hub_available:
+ return False
+ return compare_versions(parse(_hf_hub_version), operation, version)
+
+
def is_accelerate_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -740,7 +662,7 @@ def is_peft_version(operation: str, version: str):
version (`str`):
A version string
"""
- if not _peft_version:
+ if not _peft_available:
return False
return compare_versions(parse(_peft_version), operation, version)
@@ -754,11 +676,41 @@ def is_bitsandbytes_version(operation: str, version: str):
version (`str`):
A version string
"""
- if not _bitsandbytes_version:
+ if not _bitsandbytes_available:
return False
return compare_versions(parse(_bitsandbytes_version), operation, version)
+def is_gguf_version(operation: str, version: str):
+ """
+ Compares the current Accelerate version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _gguf_available:
+ return False
+ return compare_versions(parse(_gguf_version), operation, version)
+
+
+def is_torchao_version(operation: str, version: str):
+ """
+ Compares the current torchao version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _torchao_available:
+ return False
+ return compare_versions(parse(_torchao_version), operation, version)
+
+
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
@@ -774,6 +726,21 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version)
+def is_optimum_quanto_version(operation: str, version: str):
+ """
+ Compares the current Accelerate version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _optimum_quanto_available:
+ return False
+ return compare_versions(parse(_optimum_quanto_version), operation, version)
+
+
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects
diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py
index bac24fa23e63..fd66aaa4da6e 100644
--- a/src/diffusers/utils/loading_utils.py
+++ b/src/diffusers/utils/loading_utils.py
@@ -148,3 +148,15 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
module = new_module
tensor_name = splits[-1]
return module, tensor_name
+
+
+def get_submodule_by_name(root_module, module_path: str):
+ current = root_module
+ parts = module_path.split(".")
+ for part in parts:
+ if part.isdigit():
+ idx = int(part)
+ current = current[idx] # e.g., for nn.ModuleList or nn.Sequential
+ else:
+ current = getattr(current, part)
+ return current
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index dcc78a547a13..d1269fbc5f20 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -180,6 +180,8 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
# layer names without the Diffusers specific
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
+ # for now we know that the "bias" keys are only associated with `lora_B`.
+ lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)
lora_config_kwargs = {
"r": r,
@@ -188,6 +190,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
"alpha_pattern": alpha_pattern,
"target_modules": target_modules,
"use_dora": use_dora,
+ "lora_bias": lora_bias,
}
return lora_config_kwargs
@@ -254,26 +257,18 @@ def get_module_weight(weight_for_adapter, module_name):
return block_weight
- # iterate over each adapter, make it active and set the corresponding scaling weight
- for adapter_name, weight in zip(adapter_names, weights):
- for module_name, module in model.named_modules():
- if isinstance(module, BaseTunerLayer):
- # For backward compatbility with previous PEFT versions
- if hasattr(module, "set_adapter"):
- module.set_adapter(adapter_name)
- else:
- module.active_adapter = adapter_name
- module.set_scale(adapter_name, get_module_weight(weight, module_name))
-
- # set multiple active adapters
- for module in model.modules():
+ for module_name, module in model.named_modules():
if isinstance(module, BaseTunerLayer):
- # For backward compatbility with previous PEFT versions
+ # For backward compatibility with previous PEFT versions, set multiple active adapters
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_names)
else:
module.active_adapter = adapter_names
+ # Set the scaling weight for each adapter for this module
+ for adapter_name, weight in zip(adapter_names, weights):
+ module.set_scale(adapter_name, get_module_weight(weight, module_name))
+
def check_peft_version(min_version: str) -> None:
r"""
diff --git a/src/diffusers/utils/remote_utils.py b/src/diffusers/utils/remote_utils.py
new file mode 100644
index 000000000000..6494dc14171a
--- /dev/null
+++ b/src/diffusers/utils/remote_utils.py
@@ -0,0 +1,425 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+import json
+from typing import List, Literal, Optional, Union, cast
+
+import requests
+
+from .deprecation_utils import deprecate
+from .import_utils import is_safetensors_available, is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+ from ..image_processor import VaeImageProcessor
+ from ..video_processor import VideoProcessor
+
+ if is_safetensors_available():
+ import safetensors.torch
+
+ DTYPE_MAP = {
+ "float16": torch.float16,
+ "float32": torch.float32,
+ "bfloat16": torch.bfloat16,
+ "uint8": torch.uint8,
+ }
+
+
+from PIL import Image
+
+
+def detect_image_type(data: bytes) -> str:
+ if data.startswith(b"\xff\xd8"):
+ return "jpeg"
+ elif data.startswith(b"\x89PNG\r\n\x1a\n"):
+ return "png"
+ elif data.startswith(b"GIF87a") or data.startswith(b"GIF89a"):
+ return "gif"
+ elif data.startswith(b"BM"):
+ return "bmp"
+ return "unknown"
+
+
+def check_inputs_decode(
+ endpoint: str,
+ tensor: "torch.Tensor",
+ processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
+ do_scaling: bool = True,
+ scaling_factor: Optional[float] = None,
+ shift_factor: Optional[float] = None,
+ output_type: Literal["mp4", "pil", "pt"] = "pil",
+ return_type: Literal["mp4", "pil", "pt"] = "pil",
+ image_format: Literal["png", "jpg"] = "jpg",
+ partial_postprocess: bool = False,
+ input_tensor_type: Literal["binary"] = "binary",
+ output_tensor_type: Literal["binary"] = "binary",
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+):
+ if tensor.ndim == 3 and height is None and width is None:
+ raise ValueError("`height` and `width` required for packed latents.")
+ if (
+ output_type == "pt"
+ and return_type == "pil"
+ and not partial_postprocess
+ and not isinstance(processor, (VaeImageProcessor, VideoProcessor))
+ ):
+ raise ValueError("`processor` is required.")
+ if do_scaling and scaling_factor is None:
+ deprecate(
+ "do_scaling",
+ "1.0.0",
+ "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
+ standard_warn=False,
+ )
+
+
+def postprocess_decode(
+ response: requests.Response,
+ processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
+ output_type: Literal["mp4", "pil", "pt"] = "pil",
+ return_type: Literal["mp4", "pil", "pt"] = "pil",
+ partial_postprocess: bool = False,
+):
+ if output_type == "pt" or (output_type == "pil" and processor is not None):
+ output_tensor = response.content
+ parameters = response.headers
+ shape = json.loads(parameters["shape"])
+ dtype = parameters["dtype"]
+ torch_dtype = DTYPE_MAP[dtype]
+ output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
+ if output_type == "pt":
+ if partial_postprocess:
+ if return_type == "pil":
+ output = [Image.fromarray(image.numpy()) for image in output_tensor]
+ if len(output) == 1:
+ output = output[0]
+ elif return_type == "pt":
+ output = output_tensor
+ else:
+ if processor is None or return_type == "pt":
+ output = output_tensor
+ else:
+ if isinstance(processor, VideoProcessor):
+ output = cast(
+ List[Image.Image],
+ processor.postprocess_video(output_tensor, output_type="pil")[0],
+ )
+ else:
+ output = cast(
+ Image.Image,
+ processor.postprocess(output_tensor, output_type="pil")[0],
+ )
+ elif output_type == "pil" and return_type == "pil" and processor is None:
+ output = Image.open(io.BytesIO(response.content)).convert("RGB")
+ detected_format = detect_image_type(response.content)
+ output.format = detected_format
+ elif output_type == "pil" and processor is not None:
+ if return_type == "pil":
+ output = [
+ Image.fromarray(image)
+ for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
+ ]
+ elif return_type == "pt":
+ output = output_tensor
+ elif output_type == "mp4" and return_type == "mp4":
+ output = response.content
+ return output
+
+
+def prepare_decode(
+ tensor: "torch.Tensor",
+ processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
+ do_scaling: bool = True,
+ scaling_factor: Optional[float] = None,
+ shift_factor: Optional[float] = None,
+ output_type: Literal["mp4", "pil", "pt"] = "pil",
+ image_format: Literal["png", "jpg"] = "jpg",
+ partial_postprocess: bool = False,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+):
+ headers = {}
+ parameters = {
+ "image_format": image_format,
+ "output_type": output_type,
+ "partial_postprocess": partial_postprocess,
+ "shape": list(tensor.shape),
+ "dtype": str(tensor.dtype).split(".")[-1],
+ }
+ if do_scaling and scaling_factor is not None:
+ parameters["scaling_factor"] = scaling_factor
+ if do_scaling and shift_factor is not None:
+ parameters["shift_factor"] = shift_factor
+ if do_scaling and scaling_factor is None:
+ parameters["do_scaling"] = do_scaling
+ elif do_scaling and scaling_factor is None and shift_factor is None:
+ parameters["do_scaling"] = do_scaling
+ if height is not None and width is not None:
+ parameters["height"] = height
+ parameters["width"] = width
+ headers["Content-Type"] = "tensor/binary"
+ headers["Accept"] = "tensor/binary"
+ if output_type == "pil" and image_format == "jpg" and processor is None:
+ headers["Accept"] = "image/jpeg"
+ elif output_type == "pil" and image_format == "png" and processor is None:
+ headers["Accept"] = "image/png"
+ elif output_type == "mp4":
+ headers["Accept"] = "text/plain"
+ tensor_data = safetensors.torch._tobytes(tensor, "tensor")
+ return {"data": tensor_data, "params": parameters, "headers": headers}
+
+
+def remote_decode(
+ endpoint: str,
+ tensor: "torch.Tensor",
+ processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
+ do_scaling: bool = True,
+ scaling_factor: Optional[float] = None,
+ shift_factor: Optional[float] = None,
+ output_type: Literal["mp4", "pil", "pt"] = "pil",
+ return_type: Literal["mp4", "pil", "pt"] = "pil",
+ image_format: Literal["png", "jpg"] = "jpg",
+ partial_postprocess: bool = False,
+ input_tensor_type: Literal["binary"] = "binary",
+ output_tensor_type: Literal["binary"] = "binary",
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+) -> Union[Image.Image, List[Image.Image], bytes, "torch.Tensor"]:
+ """
+ Hugging Face Hybrid Inference that allow running VAE decode remotely.
+
+ Args:
+ endpoint (`str`):
+ Endpoint for Remote Decode.
+ tensor (`torch.Tensor`):
+ Tensor to be decoded.
+ processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
+ Used with `return_type="pt"`, and `return_type="pil"` for Video models.
+ do_scaling (`bool`, default `True`, *optional*):
+ **DEPRECATED**. **pass `scaling_factor`/`shift_factor` instead.** **still set
+ do_scaling=None/do_scaling=False for no scaling until option is removed** When `True` scaling e.g. `latents
+ / self.vae.config.scaling_factor` is applied remotely. If `False`, input must be passed with scaling
+ applied.
+ scaling_factor (`float`, *optional*):
+ Scaling is applied when passed e.g. [`latents /
+ self.vae.config.scaling_factor`](https://github.com/huggingface/diffusers/blob/7007febae5cff000d4df9059d9cf35133e8b2ca9/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L1083C37-L1083C77).
+ - SD v1: 0.18215
+ - SD XL: 0.13025
+ - Flux: 0.3611
+ If `None`, input must be passed with scaling applied.
+ shift_factor (`float`, *optional*):
+ Shift is applied when passed e.g. `latents + self.vae.config.shift_factor`.
+ - Flux: 0.1159
+ If `None`, input must be passed with scaling applied.
+ output_type (`"mp4"` or `"pil"` or `"pt", default `"pil"):
+ **Endpoint** output type. Subject to change. Report feedback on preferred type.
+
+ `"mp4": Supported by video models. Endpoint returns `bytes` of video. `"pil"`: Supported by image and video
+ models.
+ Image models: Endpoint returns `bytes` of an image in `image_format`. Video models: Endpoint returns
+ `torch.Tensor` with partial `postprocessing` applied.
+ Requires `processor` as a flag (any `None` value will work).
+ `"pt"`: Support by image and video models. Endpoint returns `torch.Tensor`.
+ With `partial_postprocess=True` the tensor is postprocessed `uint8` image tensor.
+
+ Recommendations:
+ `"pt"` with `partial_postprocess=True` is the smallest transfer for full quality. `"pt"` with
+ `partial_postprocess=False` is the most compatible with third party code. `"pil"` with
+ `image_format="jpg"` is the smallest transfer overall.
+
+ return_type (`"mp4"` or `"pil"` or `"pt", default `"pil"):
+ **Function** return type.
+
+ `"mp4": Function returns `bytes` of video. `"pil"`: Function returns `PIL.Image.Image`.
+ With `output_type="pil" no further processing is applied. With `output_type="pt" a `PIL.Image.Image` is
+ created.
+ `partial_postprocess=False` `processor` is required. `partial_postprocess=True` `processor` is
+ **not** required.
+ `"pt"`: Function returns `torch.Tensor`.
+ `processor` is **not** required. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without
+ denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized.
+
+ image_format (`"png"` or `"jpg"`, default `jpg`):
+ Used with `output_type="pil"`. Endpoint returns `jpg` or `png`.
+
+ partial_postprocess (`bool`, default `False`):
+ Used with `output_type="pt"`. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without
+ denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized.
+
+ input_tensor_type (`"binary"`, default `"binary"`):
+ Tensor transfer type.
+
+ output_tensor_type (`"binary"`, default `"binary"`):
+ Tensor transfer type.
+
+ height (`int`, **optional**):
+ Required for `"packed"` latents.
+
+ width (`int`, **optional**):
+ Required for `"packed"` latents.
+
+ Returns:
+ output (`Image.Image` or `List[Image.Image]` or `bytes` or `torch.Tensor`).
+ """
+ if input_tensor_type == "base64":
+ deprecate(
+ "input_tensor_type='base64'",
+ "1.0.0",
+ "input_tensor_type='base64' is deprecated. Using `binary`.",
+ standard_warn=False,
+ )
+ input_tensor_type = "binary"
+ if output_tensor_type == "base64":
+ deprecate(
+ "output_tensor_type='base64'",
+ "1.0.0",
+ "output_tensor_type='base64' is deprecated. Using `binary`.",
+ standard_warn=False,
+ )
+ output_tensor_type = "binary"
+ check_inputs_decode(
+ endpoint,
+ tensor,
+ processor,
+ do_scaling,
+ scaling_factor,
+ shift_factor,
+ output_type,
+ return_type,
+ image_format,
+ partial_postprocess,
+ input_tensor_type,
+ output_tensor_type,
+ height,
+ width,
+ )
+ kwargs = prepare_decode(
+ tensor=tensor,
+ processor=processor,
+ do_scaling=do_scaling,
+ scaling_factor=scaling_factor,
+ shift_factor=shift_factor,
+ output_type=output_type,
+ image_format=image_format,
+ partial_postprocess=partial_postprocess,
+ height=height,
+ width=width,
+ )
+ response = requests.post(endpoint, **kwargs)
+ if not response.ok:
+ raise RuntimeError(response.json())
+ output = postprocess_decode(
+ response=response,
+ processor=processor,
+ output_type=output_type,
+ return_type=return_type,
+ partial_postprocess=partial_postprocess,
+ )
+ return output
+
+
+def check_inputs_encode(
+ endpoint: str,
+ image: Union["torch.Tensor", Image.Image],
+ scaling_factor: Optional[float] = None,
+ shift_factor: Optional[float] = None,
+):
+ pass
+
+
+def postprocess_encode(
+ response: requests.Response,
+):
+ output_tensor = response.content
+ parameters = response.headers
+ shape = json.loads(parameters["shape"])
+ dtype = parameters["dtype"]
+ torch_dtype = DTYPE_MAP[dtype]
+ output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
+ return output_tensor
+
+
+def prepare_encode(
+ image: Union["torch.Tensor", Image.Image],
+ scaling_factor: Optional[float] = None,
+ shift_factor: Optional[float] = None,
+):
+ headers = {}
+ parameters = {}
+ if scaling_factor is not None:
+ parameters["scaling_factor"] = scaling_factor
+ if shift_factor is not None:
+ parameters["shift_factor"] = shift_factor
+ if isinstance(image, torch.Tensor):
+ data = safetensors.torch._tobytes(image.contiguous(), "tensor")
+ parameters["shape"] = list(image.shape)
+ parameters["dtype"] = str(image.dtype).split(".")[-1]
+ else:
+ buffer = io.BytesIO()
+ image.save(buffer, format="PNG")
+ data = buffer.getvalue()
+ return {"data": data, "params": parameters, "headers": headers}
+
+
+def remote_encode(
+ endpoint: str,
+ image: Union["torch.Tensor", Image.Image],
+ scaling_factor: Optional[float] = None,
+ shift_factor: Optional[float] = None,
+) -> "torch.Tensor":
+ """
+ Hugging Face Hybrid Inference that allow running VAE encode remotely.
+
+ Args:
+ endpoint (`str`):
+ Endpoint for Remote Decode.
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ Image to be encoded.
+ scaling_factor (`float`, *optional*):
+ Scaling is applied when passed e.g. [`latents * self.vae.config.scaling_factor`].
+ - SD v1: 0.18215
+ - SD XL: 0.13025
+ - Flux: 0.3611
+ If `None`, input must be passed with scaling applied.
+ shift_factor (`float`, *optional*):
+ Shift is applied when passed e.g. `latents - self.vae.config.shift_factor`.
+ - Flux: 0.1159
+ If `None`, input must be passed with scaling applied.
+
+ Returns:
+ output (`torch.Tensor`).
+ """
+ check_inputs_encode(
+ endpoint,
+ image,
+ scaling_factor,
+ shift_factor,
+ )
+ kwargs = prepare_encode(
+ image=image,
+ scaling_factor=scaling_factor,
+ shift_factor=shift_factor,
+ )
+ response = requests.post(endpoint, **kwargs)
+ if not response.ok:
+ raise RuntimeError(response.json())
+ output = postprocess_encode(
+ response=response,
+ )
+ return output
diff --git a/src/diffusers/utils/source_code_parsing_utils.py b/src/diffusers/utils/source_code_parsing_utils.py
new file mode 100644
index 000000000000..5f94711c21d8
--- /dev/null
+++ b/src/diffusers/utils/source_code_parsing_utils.py
@@ -0,0 +1,52 @@
+import ast
+import importlib
+import inspect
+import textwrap
+
+
+class ReturnNameVisitor(ast.NodeVisitor):
+ """Thanks to ChatGPT for pairing."""
+
+ def __init__(self):
+ self.return_names = []
+
+ def visit_Return(self, node):
+ # Check if the return value is a tuple.
+ if isinstance(node.value, ast.Tuple):
+ for elt in node.value.elts:
+ if isinstance(elt, ast.Name):
+ self.return_names.append(elt.id)
+ else:
+ try:
+ self.return_names.append(ast.unparse(elt))
+ except Exception:
+ self.return_names.append(str(elt))
+ else:
+ if isinstance(node.value, ast.Name):
+ self.return_names.append(node.value.id)
+ else:
+ try:
+ self.return_names.append(ast.unparse(node.value))
+ except Exception:
+ self.return_names.append(str(node.value))
+ self.generic_visit(node)
+
+ def _determine_parent_module(self, cls):
+ from diffusers import DiffusionPipeline
+ from diffusers.models.modeling_utils import ModelMixin
+
+ if issubclass(cls, DiffusionPipeline):
+ return "pipelines"
+ elif issubclass(cls, ModelMixin):
+ return "models"
+ else:
+ raise NotImplementedError
+
+ def get_ast_tree(self, cls, attribute_name="encode_prompt"):
+ parent_module_name = self._determine_parent_module(cls)
+ main_module = importlib.import_module(f"diffusers.{parent_module_name}")
+ current_cls_module = getattr(main_module, cls.__name__)
+ source_code = inspect.getsource(getattr(current_cls_module, attribute_name))
+ source_code = textwrap.dedent(source_code)
+ tree = ast.parse(source_code)
+ return tree
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 1179b113d636..e62f245f9ed1 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -32,6 +32,7 @@
is_bitsandbytes_available,
is_compel_available,
is_flax_available,
+ is_gguf_available,
is_note_seq_available,
is_onnx_available,
is_opencv_available,
@@ -39,6 +40,7 @@
is_timm_available,
is_torch_available,
is_torch_version,
+ is_torchao_available,
is_torchsde_available,
is_transformers_available,
)
@@ -57,6 +59,7 @@
) > version.parse("4.33")
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
+BIG_GPU_MEMORY = int(os.getenv("BIG_GPU_MEMORY", 40))
if is_torch_available():
import torch
@@ -83,7 +86,12 @@
) from e
logger.info(f"torch_device overrode to {torch_device}")
else:
- torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+ if torch.cuda.is_available():
+ torch_device = "cuda"
+ elif torch.xpu.is_available():
+ torch_device = "xpu"
+ else:
+ torch_device = "cpu"
is_torch_higher_equal_than_1_12 = version.parse(
version.parse(torch.__version__).base_version
) >= version.parse("1.12")
@@ -93,6 +101,8 @@
mps_backend_registered = hasattr(torch.backends, "mps")
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
+ from .torch_utils import get_torch_cuda_device_capability
+
def torch_all_close(a, b, *args, **kwargs):
if not is_torch_available():
@@ -274,6 +284,20 @@ def require_torch_gpu(test_case):
)
+def require_torch_cuda_compatibility(expected_compute_capability):
+ def decorator(test_case):
+ if not torch.cuda.is_available():
+ return unittest.skip(test_case)
+ else:
+ current_compute_capability = get_torch_cuda_device_capability()
+ return unittest.skipUnless(
+ float(current_compute_capability) == float(expected_compute_capability),
+ "Test not supported for this compute capability.",
+ )
+
+ return decorator
+
+
# These decorators are for accelerator-specific behaviours that are not GPU-specific
def require_torch_accelerator(test_case):
"""Decorator marking a test that requires an accelerator backend and PyTorch."""
@@ -296,6 +320,21 @@ def require_torch_multi_gpu(test_case):
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
+def require_torch_multi_accelerator(test_case):
+ """
+ Decorator marking a test that requires a multi-accelerator setup (in PyTorch). These tests are skipped on a machine
+ without multiple hardware accelerators.
+ """
+ if not is_torch_available():
+ return unittest.skip("test requires PyTorch")(test_case)
+
+ import torch
+
+ return unittest.skipUnless(
+ torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
+ )(test_case)
+
+
def require_torch_accelerator_with_fp16(test_case):
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
@@ -310,6 +349,51 @@ def require_torch_accelerator_with_fp64(test_case):
)
+def require_big_gpu_with_torch_cuda(test_case):
+ """
+ Decorator marking a test that requires a bigger GPU (24GB) for execution. Some example pipelines: Flux, SD3, Cog,
+ etc.
+ """
+ if not is_torch_available():
+ return unittest.skip("test requires PyTorch")(test_case)
+
+ import torch
+
+ if not torch.cuda.is_available():
+ return unittest.skip("test requires PyTorch CUDA")(test_case)
+
+ device_properties = torch.cuda.get_device_properties(0)
+ total_memory = device_properties.total_memory / (1024**3)
+ return unittest.skipUnless(
+ total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
+ )(test_case)
+
+
+def require_big_accelerator(test_case):
+ """
+ Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
+ Flux, SD3, Cog, etc.
+ """
+ if not is_torch_available():
+ return unittest.skip("test requires PyTorch")(test_case)
+
+ import torch
+
+ if not (torch.cuda.is_available() or torch.xpu.is_available()):
+ return unittest.skip("test requires PyTorch CUDA")(test_case)
+
+ if torch.xpu.is_available():
+ device_properties = torch.xpu.get_device_properties(0)
+ else:
+ device_properties = torch.cuda.get_device_properties(0)
+
+ total_memory = device_properties.total_memory / (1024**3)
+ return unittest.skipUnless(
+ total_memory >= BIG_GPU_MEMORY,
+ f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
+ )(test_case)
+
+
def require_torch_accelerator_with_training(test_case):
"""Decorator marking a test that requires an accelerator with support for training."""
return unittest.skipUnless(
@@ -352,6 +436,14 @@ def require_note_seq(test_case):
return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
+def require_accelerator(test_case):
+ """
+ Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
+ hardware accelerator available.
+ """
+ return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)
+
+
def require_torchsde(test_case):
"""
Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
@@ -425,7 +517,7 @@ def decorator(test_case):
def require_accelerate_version_greater(accelerate_version):
def decorator(test_case):
- correct_accelerate_version = is_peft_available() and version.parse(
+ correct_accelerate_version = is_accelerate_available() and version.parse(
version.parse(importlib.metadata.version("accelerate")).base_version
) > version.parse(accelerate_version)
return unittest.skipUnless(
@@ -447,6 +539,42 @@ def decorator(test_case):
return decorator
+def require_hf_hub_version_greater(hf_hub_version):
+ def decorator(test_case):
+ correct_hf_hub_version = version.parse(
+ version.parse(importlib.metadata.version("huggingface_hub")).base_version
+ ) > version.parse(hf_hub_version)
+ return unittest.skipUnless(
+ correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_gguf_version_greater_or_equal(gguf_version):
+ def decorator(test_case):
+ correct_gguf_version = is_gguf_available() and version.parse(
+ version.parse(importlib.metadata.version("gguf")).base_version
+ ) >= version.parse(gguf_version)
+ return unittest.skipUnless(
+ correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_torchao_version_greater_or_equal(torchao_version):
+ def decorator(test_case):
+ correct_torchao_version = is_torchao_available() and version.parse(
+ version.parse(importlib.metadata.version("torchao")).base_version
+ ) >= version.parse(torchao_version)
+ return unittest.skipUnless(
+ correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
+ )(test_case)
+
+ return decorator
+
+
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
@@ -486,10 +614,10 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
return arry
-def load_pt(url: str):
+def load_pt(url: str, map_location: str):
response = requests.get(url)
response.raise_for_status()
- arry = torch.load(BytesIO(response.content))
+ arry = torch.load(BytesIO(response.content), map_location=map_location)
return arry
@@ -1000,12 +1128,51 @@ def _is_torch_fp64_available(device):
# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
if is_torch_available():
# Behaviour flags
- BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
+ BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
# Function definitions
- BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
- BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
- BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+ BACKEND_EMPTY_CACHE = {
+ "cuda": torch.cuda.empty_cache,
+ "xpu": torch.xpu.empty_cache,
+ "cpu": None,
+ "mps": torch.mps.empty_cache,
+ "default": None,
+ }
+ BACKEND_DEVICE_COUNT = {
+ "cuda": torch.cuda.device_count,
+ "xpu": torch.xpu.device_count,
+ "cpu": lambda: 0,
+ "mps": lambda: 0,
+ "default": 0,
+ }
+ BACKEND_MANUAL_SEED = {
+ "cuda": torch.cuda.manual_seed,
+ "xpu": torch.xpu.manual_seed,
+ "cpu": torch.manual_seed,
+ "mps": torch.mps.manual_seed,
+ "default": torch.manual_seed,
+ }
+ BACKEND_RESET_PEAK_MEMORY_STATS = {
+ "cuda": torch.cuda.reset_peak_memory_stats,
+ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+ "cpu": None,
+ "mps": None,
+ "default": None,
+ }
+ BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+ "cuda": torch.cuda.reset_max_memory_allocated,
+ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+ "cpu": None,
+ "mps": None,
+ "default": None,
+ }
+ BACKEND_MAX_MEMORY_ALLOCATED = {
+ "cuda": torch.cuda.max_memory_allocated,
+ "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+ "cpu": 0,
+ "mps": 0,
+ "default": 0,
+ }
# This dispatches a defined function according to the accelerator from the function definitions.
@@ -1036,6 +1203,18 @@ def backend_device_count(device: str):
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+def backend_reset_peak_memory_stats(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_reset_max_memory_allocated(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_max_memory_allocated(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
+
# These are callables which return boolean behaviour flags and can be used to specify some
# device agnostic alternative where the feature is unsupported.
def backend_supports_training(device: str):
@@ -1092,3 +1271,6 @@ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
+ update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN")
+ update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN")
+ update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 0cf75b4fad4e..3c8911773e39 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -102,6 +102,9 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
# Non-power of 2 images must be float32
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
x = x.to(dtype=torch.float32)
+ # fftn does not support bfloat16
+ elif x.dtype == torch.bfloat16:
+ x = x.to(dtype=torch.float32)
# FFT
x_freq = fftn(x, dim=(-2, -1))
@@ -146,3 +149,13 @@ def apply_freeu(
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
return hidden_states, res_hidden_states
+
+
+def get_torch_cuda_device_capability():
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ compute_capability = torch.cuda.get_device_capability(device)
+ compute_capability = f"{compute_capability[0]}.{compute_capability[1]}"
+ return float(compute_capability)
+ else:
+ return None
diff --git a/src/diffusers/utils/typing_utils.py b/src/diffusers/utils/typing_utils.py
new file mode 100644
index 000000000000..2b5b1a4f5ab5
--- /dev/null
+++ b/src/diffusers/utils/typing_utils.py
@@ -0,0 +1,91 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Typing utilities: Utilities related to type checking and validation
+"""
+
+from typing import Any, Dict, List, Set, Tuple, Type, Union, get_args, get_origin
+
+
+def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
+ """
+ Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
+ the correct type as well.
+ """
+ if not isinstance(class_or_tuple, tuple):
+ class_or_tuple = (class_or_tuple,)
+
+ # Unpack unions
+ unpacked_class_or_tuple = []
+ for t in class_or_tuple:
+ if get_origin(t) is Union:
+ unpacked_class_or_tuple.extend(get_args(t))
+ else:
+ unpacked_class_or_tuple.append(t)
+ class_or_tuple = tuple(unpacked_class_or_tuple)
+
+ if Any in class_or_tuple:
+ return True
+
+ obj_type = type(obj)
+ # Classes with obj's type
+ class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
+
+ # Singular types (e.g. int, ControlNet, ...)
+ # Untyped collections (e.g. List, but not List[int])
+ elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
+ if () in elem_class_or_tuple:
+ return True
+ # Typed lists or sets
+ elif obj_type in (list, set):
+ return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
+ # Typed tuples
+ elif obj_type is tuple:
+ return any(
+ # Tuples with any length and single type (e.g. Tuple[int, ...])
+ (len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
+ or
+ # Tuples with fixed length and any types (e.g. Tuple[int, str])
+ (len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
+ for t in elem_class_or_tuple
+ )
+ # Typed dicts
+ elif obj_type is dict:
+ return any(
+ all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
+ for kt, vt in elem_class_or_tuple
+ )
+
+ else:
+ return False
+
+
+def _get_detailed_type(obj: Any) -> Type:
+ """
+ Gets a detailed type for an object, including nested types for collections.
+ """
+ obj_type = type(obj)
+
+ if obj_type in (list, set):
+ obj_origin_type = List if obj_type is list else Set
+ elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
+ return obj_origin_type[elems_type]
+ elif obj_type is tuple:
+ return Tuple[tuple(_get_detailed_type(x) for x in obj)]
+ elif obj_type is dict:
+ keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
+ values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
+ return Dict[keys_type, values_type]
+ else:
+ return obj_type
diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py
index 9e2727b85377..2da782b463d4 100644
--- a/src/diffusers/video_processor.py
+++ b/src/diffusers/video_processor.py
@@ -67,7 +67,7 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[
# ensure the input is a list of videos:
# - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
- # - if it is is a single video, it is convereted to a list of one video.
+ # - if it is a single video, it is convereted to a list of one video.
if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
video = list(video)
elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
diff --git a/tests/fixtures/custom_pipeline/pipeline.py b/tests/fixtures/custom_pipeline/pipeline.py
index 601f51b1263e..e197cb6859fa 100644
--- a/tests/fixtures/custom_pipeline/pipeline.py
+++ b/tests/fixtures/custom_pipeline/pipeline.py
@@ -18,7 +18,7 @@
import torch
-from diffusers import DiffusionPipeline, ImagePipelineOutput
+from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel
class CustomLocalPipeline(DiffusionPipeline):
@@ -33,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline):
[`DDPMScheduler`], or [`DDIMScheduler`].
"""
- def __init__(self, unet, scheduler):
+ def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
diff --git a/tests/fixtures/custom_pipeline/what_ever.py b/tests/fixtures/custom_pipeline/what_ever.py
index 8ceeb4211e37..bbe7f4f16bd8 100644
--- a/tests/fixtures/custom_pipeline/what_ever.py
+++ b/tests/fixtures/custom_pipeline/what_ever.py
@@ -18,6 +18,7 @@
import torch
+from diffusers import SchedulerMixin, UNet2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -33,7 +34,7 @@ class CustomLocalPipeline(DiffusionPipeline):
[`DDPMScheduler`], or [`DDIMScheduler`].
"""
- def __init__(self, unet, scheduler):
+ def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py
new file mode 100644
index 000000000000..d8f41fc2b1ae
--- /dev/null
+++ b/tests/hooks/test_group_offloading.py
@@ -0,0 +1,214 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+
+from diffusers.models import ModelMixin
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import get_logger
+from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+
+
+class DummyBlock(torch.nn.Module):
+ def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
+ super().__init__()
+
+ self.proj_in = torch.nn.Linear(in_features, hidden_features)
+ self.activation = torch.nn.ReLU()
+ self.proj_out = torch.nn.Linear(hidden_features, out_features)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj_in(x)
+ x = self.activation(x)
+ x = self.proj_out(x)
+ return x
+
+
+class DummyModel(ModelMixin):
+ def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
+ super().__init__()
+
+ self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+ self.activation = torch.nn.ReLU()
+ self.blocks = torch.nn.ModuleList(
+ [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+ )
+ self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.linear_1(x)
+ x = self.activation(x)
+ for block in self.blocks:
+ x = block(x)
+ x = self.linear_2(x)
+ return x
+
+
+class DummyPipeline(DiffusionPipeline):
+ model_cpu_offload_seq = "model"
+
+ def __init__(self, model: torch.nn.Module) -> None:
+ super().__init__()
+
+ self.register_modules(model=model)
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ for _ in range(2):
+ x = x + 0.1 * self.model(x)
+ return x
+
+
+@require_torch_gpu
+class GroupOffloadTests(unittest.TestCase):
+ in_features = 64
+ hidden_features = 256
+ out_features = 64
+ num_layers = 4
+
+ def setUp(self):
+ with torch.no_grad():
+ self.model = self.get_model()
+ self.input = torch.randn((4, self.in_features)).to(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+
+ del self.model
+ del self.input
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ def get_model(self):
+ torch.manual_seed(0)
+ return DummyModel(
+ in_features=self.in_features,
+ hidden_features=self.hidden_features,
+ out_features=self.out_features,
+ num_layers=self.num_layers,
+ )
+
+ def test_offloading_forward_pass(self):
+ @torch.no_grad()
+ def run_forward(model):
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ self.assertTrue(
+ all(
+ module._diffusers_hook.get_hook("group_offloading") is not None
+ for module in model.modules()
+ if hasattr(module, "_diffusers_hook")
+ )
+ )
+ model.eval()
+ output = model(self.input)[0].cpu()
+ max_memory_allocated = torch.cuda.max_memory_allocated()
+ return output, max_memory_allocated
+
+ self.model.to(torch_device)
+ output_without_group_offloading, mem_baseline = run_forward(self.model)
+ self.model.to("cpu")
+
+ model = self.get_model()
+ model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+ output_with_group_offloading1, mem1 = run_forward(model)
+
+ model = self.get_model()
+ model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
+ output_with_group_offloading2, mem2 = run_forward(model)
+
+ model = self.get_model()
+ model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
+ output_with_group_offloading3, mem3 = run_forward(model)
+
+ model = self.get_model()
+ model.enable_group_offload(torch_device, offload_type="leaf_level")
+ output_with_group_offloading4, mem4 = run_forward(model)
+
+ model = self.get_model()
+ model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True)
+ output_with_group_offloading5, mem5 = run_forward(model)
+
+ # Precision assertions - offloading should not impact the output
+ self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
+ self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
+ self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
+ self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
+ self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
+
+ # Memory assertions - offloading should reduce memory usage
+ self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)
+
+ def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
+ if torch.device(torch_device).type != "cuda":
+ return
+ self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+ logger = get_logger("diffusers.models.modeling_utils")
+ logger.setLevel("INFO")
+ with self.assertLogs(logger, level="WARNING") as cm:
+ self.model.to(torch_device)
+ self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
+
+ def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
+ if torch.device(torch_device).type != "cuda":
+ return
+ pipe = DummyPipeline(self.model)
+ self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+ logger = get_logger("diffusers.pipelines.pipeline_utils")
+ logger.setLevel("INFO")
+ with self.assertLogs(logger, level="WARNING") as cm:
+ pipe.to(torch_device)
+ self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
+
+ def test_error_raised_if_streams_used_and_no_cuda_device(self):
+ original_is_available = torch.cuda.is_available
+ torch.cuda.is_available = lambda: False
+ with self.assertRaises(ValueError):
+ self.model.enable_group_offload(
+ onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
+ )
+ torch.cuda.is_available = original_is_available
+
+ def test_error_raised_if_supports_group_offloading_false(self):
+ self.model._supports_group_offloading = False
+ with self.assertRaisesRegex(ValueError, "does not support group offloading"):
+ self.model.enable_group_offload(onload_device=torch.device("cuda"))
+
+ def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
+ pipe = DummyPipeline(self.model)
+ pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+ with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
+ pipe.enable_model_cpu_offload()
+
+ def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self):
+ pipe = DummyPipeline(self.model)
+ pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+ with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
+ pipe.enable_sequential_cpu_offload()
+
+ def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self):
+ pipe = DummyPipeline(self.model)
+ pipe.enable_model_cpu_offload()
+ with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
+ pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+
+ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self):
+ pipe = DummyPipeline(self.model)
+ pipe.enable_sequential_cpu_offload()
+ with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
+ pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py
new file mode 100644
index 000000000000..74bd43c52315
--- /dev/null
+++ b/tests/hooks/test_hooks.py
@@ -0,0 +1,382 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+
+from diffusers.hooks import HookRegistry, ModelHook
+from diffusers.training_utils import free_memory
+from diffusers.utils.logging import get_logger
+from diffusers.utils.testing_utils import CaptureLogger, torch_device
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+class DummyBlock(torch.nn.Module):
+ def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
+ super().__init__()
+
+ self.proj_in = torch.nn.Linear(in_features, hidden_features)
+ self.activation = torch.nn.ReLU()
+ self.proj_out = torch.nn.Linear(hidden_features, out_features)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj_in(x)
+ x = self.activation(x)
+ x = self.proj_out(x)
+ return x
+
+
+class DummyModel(torch.nn.Module):
+ def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
+ super().__init__()
+
+ self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+ self.activation = torch.nn.ReLU()
+ self.blocks = torch.nn.ModuleList(
+ [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+ )
+ self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.linear_1(x)
+ x = self.activation(x)
+ for block in self.blocks:
+ x = block(x)
+ x = self.linear_2(x)
+ return x
+
+
+class AddHook(ModelHook):
+ def __init__(self, value: int):
+ super().__init__()
+ self.value = value
+
+ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
+ logger.debug("AddHook pre_forward")
+ args = ((x + self.value) if torch.is_tensor(x) else x for x in args)
+ return args, kwargs
+
+ def post_forward(self, module, output):
+ logger.debug("AddHook post_forward")
+ return output
+
+
+class MultiplyHook(ModelHook):
+ def __init__(self, value: int):
+ super().__init__()
+ self.value = value
+
+ def pre_forward(self, module, *args, **kwargs):
+ logger.debug("MultiplyHook pre_forward")
+ args = ((x * self.value) if torch.is_tensor(x) else x for x in args)
+ return args, kwargs
+
+ def post_forward(self, module, output):
+ logger.debug("MultiplyHook post_forward")
+ return output
+
+ def __repr__(self):
+ return f"MultiplyHook(value={self.value})"
+
+
+class StatefulAddHook(ModelHook):
+ _is_stateful = True
+
+ def __init__(self, value: int):
+ super().__init__()
+ self.value = value
+ self.increment = 0
+
+ def pre_forward(self, module, *args, **kwargs):
+ logger.debug("StatefulAddHook pre_forward")
+ add_value = self.value + self.increment
+ self.increment += 1
+ args = ((x + add_value) if torch.is_tensor(x) else x for x in args)
+ return args, kwargs
+
+ def reset_state(self, module):
+ self.increment = 0
+
+
+class SkipLayerHook(ModelHook):
+ def __init__(self, skip_layer: bool):
+ super().__init__()
+ self.skip_layer = skip_layer
+
+ def pre_forward(self, module, *args, **kwargs):
+ logger.debug("SkipLayerHook pre_forward")
+ return args, kwargs
+
+ def new_forward(self, module, *args, **kwargs):
+ logger.debug("SkipLayerHook new_forward")
+ if self.skip_layer:
+ return args[0]
+ return self.fn_ref.original_forward(*args, **kwargs)
+
+ def post_forward(self, module, output):
+ logger.debug("SkipLayerHook post_forward")
+ return output
+
+
+class HookTests(unittest.TestCase):
+ in_features = 4
+ hidden_features = 8
+ out_features = 4
+ num_layers = 2
+
+ def setUp(self):
+ params = self.get_module_parameters()
+ self.model = DummyModel(**params)
+ self.model.to(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+
+ del self.model
+ gc.collect()
+ free_memory()
+
+ def get_module_parameters(self):
+ return {
+ "in_features": self.in_features,
+ "hidden_features": self.hidden_features,
+ "out_features": self.out_features,
+ "num_layers": self.num_layers,
+ }
+
+ def get_generator(self):
+ return torch.manual_seed(0)
+
+ def test_hook_registry(self):
+ registry = HookRegistry.check_if_exists_or_initialize(self.model)
+ registry.register_hook(AddHook(1), "add_hook")
+ registry.register_hook(MultiplyHook(2), "multiply_hook")
+
+ registry_repr = repr(registry)
+ expected_repr = (
+ "HookRegistry(\n" " (0) add_hook - AddHook\n" " (1) multiply_hook - MultiplyHook(value=2)\n" ")"
+ )
+
+ self.assertEqual(len(registry.hooks), 2)
+ self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
+ self.assertEqual(registry_repr, expected_repr)
+
+ registry.remove_hook("add_hook")
+
+ self.assertEqual(len(registry.hooks), 1)
+ self.assertEqual(registry._hook_order, ["multiply_hook"])
+
+ def test_stateful_hook(self):
+ registry = HookRegistry.check_if_exists_or_initialize(self.model)
+ registry.register_hook(StatefulAddHook(1), "stateful_add_hook")
+
+ self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)
+
+ input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
+ num_repeats = 3
+
+ for i in range(num_repeats):
+ result = self.model(input)
+ if i == 0:
+ output1 = result
+
+ self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)
+
+ registry.reset_stateful_hooks()
+ output2 = self.model(input)
+
+ self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
+ self.assertTrue(torch.allclose(output1, output2))
+
+ def test_inference(self):
+ registry = HookRegistry.check_if_exists_or_initialize(self.model)
+ registry.register_hook(AddHook(1), "add_hook")
+ registry.register_hook(MultiplyHook(2), "multiply_hook")
+
+ input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
+ output1 = self.model(input).mean().detach().cpu().item()
+
+ registry.remove_hook("multiply_hook")
+ new_input = input * 2
+ output2 = self.model(new_input).mean().detach().cpu().item()
+
+ registry.remove_hook("add_hook")
+ new_input = input * 2 + 1
+ output3 = self.model(new_input).mean().detach().cpu().item()
+
+ self.assertAlmostEqual(output1, output2, places=5)
+ self.assertAlmostEqual(output1, output3, places=5)
+
+ def test_skip_layer_hook(self):
+ registry = HookRegistry.check_if_exists_or_initialize(self.model)
+ registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
+
+ input = torch.zeros(1, 4, device=torch_device)
+ output = self.model(input).mean().detach().cpu().item()
+ self.assertEqual(output, 0.0)
+
+ registry.remove_hook("skip_layer_hook")
+ registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
+ output = self.model(input).mean().detach().cpu().item()
+ self.assertNotEqual(output, 0.0)
+
+ def test_skip_layer_internal_block(self):
+ registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
+ input = torch.zeros(1, 4, device=torch_device)
+
+ registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
+ with self.assertRaises(RuntimeError) as cm:
+ self.model(input).mean().detach().cpu().item()
+ self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))
+
+ registry.remove_hook("skip_layer_hook")
+ output = self.model(input).mean().detach().cpu().item()
+ self.assertNotEqual(output, 0.0)
+
+ registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
+ registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
+ output = self.model(input).mean().detach().cpu().item()
+ self.assertNotEqual(output, 0.0)
+
+ def test_invocation_order_stateful_first(self):
+ registry = HookRegistry.check_if_exists_or_initialize(self.model)
+ registry.register_hook(StatefulAddHook(1), "add_hook")
+ registry.register_hook(AddHook(2), "add_hook_2")
+ registry.register_hook(MultiplyHook(3), "multiply_hook")
+
+ input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
+
+ logger = get_logger(__name__)
+ logger.setLevel("DEBUG")
+
+ with CaptureLogger(logger) as cap_logger:
+ self.model(input)
+ output = cap_logger.out.replace(" ", "").replace("\n", "")
+ expected_invocation_order_log = (
+ (
+ "MultiplyHook pre_forward\n"
+ "AddHook pre_forward\n"
+ "StatefulAddHook pre_forward\n"
+ "AddHook post_forward\n"
+ "MultiplyHook post_forward\n"
+ )
+ .replace(" ", "")
+ .replace("\n", "")
+ )
+ self.assertEqual(output, expected_invocation_order_log)
+
+ registry.remove_hook("add_hook")
+ with CaptureLogger(logger) as cap_logger:
+ self.model(input)
+ output = cap_logger.out.replace(" ", "").replace("\n", "")
+ expected_invocation_order_log = (
+ (
+ "MultiplyHook pre_forward\n"
+ "AddHook pre_forward\n"
+ "AddHook post_forward\n"
+ "MultiplyHook post_forward\n"
+ )
+ .replace(" ", "")
+ .replace("\n", "")
+ )
+ self.assertEqual(output, expected_invocation_order_log)
+
+ def test_invocation_order_stateful_middle(self):
+ registry = HookRegistry.check_if_exists_or_initialize(self.model)
+ registry.register_hook(AddHook(2), "add_hook")
+ registry.register_hook(StatefulAddHook(1), "add_hook_2")
+ registry.register_hook(MultiplyHook(3), "multiply_hook")
+
+ input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
+
+ logger = get_logger(__name__)
+ logger.setLevel("DEBUG")
+
+ with CaptureLogger(logger) as cap_logger:
+ self.model(input)
+ output = cap_logger.out.replace(" ", "").replace("\n", "")
+ expected_invocation_order_log = (
+ (
+ "MultiplyHook pre_forward\n"
+ "StatefulAddHook pre_forward\n"
+ "AddHook pre_forward\n"
+ "AddHook post_forward\n"
+ "MultiplyHook post_forward\n"
+ )
+ .replace(" ", "")
+ .replace("\n", "")
+ )
+ self.assertEqual(output, expected_invocation_order_log)
+
+ registry.remove_hook("add_hook")
+ with CaptureLogger(logger) as cap_logger:
+ self.model(input)
+ output = cap_logger.out.replace(" ", "").replace("\n", "")
+ expected_invocation_order_log = (
+ ("MultiplyHook pre_forward\nStatefulAddHook pre_forward\nMultiplyHook post_forward\n")
+ .replace(" ", "")
+ .replace("\n", "")
+ )
+ self.assertEqual(output, expected_invocation_order_log)
+
+ registry.remove_hook("add_hook_2")
+ with CaptureLogger(logger) as cap_logger:
+ self.model(input)
+ output = cap_logger.out.replace(" ", "").replace("\n", "")
+ expected_invocation_order_log = (
+ ("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
+ )
+ self.assertEqual(output, expected_invocation_order_log)
+
+ def test_invocation_order_stateful_last(self):
+ registry = HookRegistry.check_if_exists_or_initialize(self.model)
+ registry.register_hook(AddHook(1), "add_hook")
+ registry.register_hook(MultiplyHook(2), "multiply_hook")
+ registry.register_hook(StatefulAddHook(3), "add_hook_2")
+
+ input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
+
+ logger = get_logger(__name__)
+ logger.setLevel("DEBUG")
+
+ with CaptureLogger(logger) as cap_logger:
+ self.model(input)
+ output = cap_logger.out.replace(" ", "").replace("\n", "")
+ expected_invocation_order_log = (
+ (
+ "StatefulAddHook pre_forward\n"
+ "MultiplyHook pre_forward\n"
+ "AddHook pre_forward\n"
+ "AddHook post_forward\n"
+ "MultiplyHook post_forward\n"
+ )
+ .replace(" ", "")
+ .replace("\n", "")
+ )
+ self.assertEqual(output, expected_invocation_order_log)
+
+ registry.remove_hook("add_hook")
+ with CaptureLogger(logger) as cap_logger:
+ self.model(input)
+ output = cap_logger.out.replace(" ", "").replace("\n", "")
+ expected_invocation_order_log = (
+ ("StatefulAddHook pre_forward\nMultiplyHook pre_forward\nMultiplyHook post_forward\n")
+ .replace(" ", "")
+ .replace("\n", "")
+ )
+ self.assertEqual(output, expected_invocation_order_log)
diff --git a/tests/lora/test_deprecated_utilities.py b/tests/lora/test_deprecated_utilities.py
new file mode 100644
index 000000000000..4275ef8089a3
--- /dev/null
+++ b/tests/lora/test_deprecated_utilities.py
@@ -0,0 +1,39 @@
+import os
+import tempfile
+import unittest
+
+import torch
+
+from diffusers.loaders.lora_base import LoraBaseMixin
+
+
+class UtilityMethodDeprecationTests(unittest.TestCase):
+ def test_fetch_state_dict_cls_method_raises_warning(self):
+ state_dict = torch.nn.Linear(3, 3).state_dict()
+ with self.assertWarns(FutureWarning) as warning:
+ _ = LoraBaseMixin._fetch_state_dict(
+ state_dict,
+ weight_name=None,
+ use_safetensors=False,
+ local_files_only=True,
+ cache_dir=None,
+ force_download=False,
+ proxies=None,
+ token=None,
+ revision=None,
+ subfolder=None,
+ user_agent=None,
+ allow_pickle=None,
+ )
+ warning_message = str(warning.warnings[0].message)
+ assert "Using the `_fetch_state_dict()` method from" in warning_message
+
+ def test_best_guess_weight_name_cls_method_raises_warning(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ state_dict = torch.nn.Linear(3, 3).state_dict()
+ torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+
+ with self.assertWarns(FutureWarning) as warning:
+ _ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)
+ warning_message = str(warning.warnings[0].message)
+ assert "Using the `_best_guess_weight_name()` method from" in warning_message
diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py
index c141ebc96b3e..dc2695452c2f 100644
--- a/tests/lora/test_lora_layers_cogvideox.py
+++ b/tests/lora/test_lora_layers_cogvideox.py
@@ -15,7 +15,6 @@
import sys
import unittest
-import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -28,19 +27,13 @@
)
from diffusers.utils.testing_utils import (
floats_tensor,
- is_peft_available,
require_peft_backend,
- skip_mps,
- torch_device,
)
-if is_peft_available():
- pass
-
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
+from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -125,36 +118,6 @@ def get_dummy_inputs(self, with_generator=True):
return noise, input_ids, pipeline_inputs
- @skip_mps
- def test_lora_fuse_nan(self):
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
-
- self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
-
- # corrupt one LoRA weight with `inf` values
- with torch.no_grad():
- pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
-
- # with `safe_fusing=True` we should see an Error
- with self.assertRaises(ValueError):
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
-
- # without we should not see an error, but every image will be black
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
-
- out = pipe(
- "test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
- )[0]
-
- self.assertTrue(np.isnan(out).all())
-
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
@@ -192,3 +155,7 @@ def test_simple_inference_with_text_lora_fused(self):
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+
+ @unittest.skip("Not supported in CogVideoX.")
+ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+ pass
diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py
new file mode 100644
index 000000000000..178de2069b7e
--- /dev/null
+++ b/tests/lora/test_lora_layers_cogview4.py
@@ -0,0 +1,174 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, GlmModel
+
+from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
+
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+class TokenizerWrapper:
+ @staticmethod
+ def from_pretrained(*args, **kwargs):
+ return AutoTokenizer.from_pretrained(
+ "hf-internal-testing/tiny-random-cogview4", subfolder="tokenizer", trust_remote_code=True
+ )
+
+
+@require_peft_backend
+@skip_mps
+class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = CogView4Pipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "patch_size": 2,
+ "in_channels": 4,
+ "num_layers": 2,
+ "attention_head_dim": 4,
+ "num_attention_heads": 4,
+ "out_channels": 4,
+ "text_embed_dim": 32,
+ "time_embed_dim": 8,
+ "condition_dim": 4,
+ }
+ transformer_cls = CogView4Transformer2DModel
+ vae_kwargs = {
+ "block_out_channels": [32, 64],
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ "latent_channels": 4,
+ "sample_size": 128,
+ }
+ vae_cls = AutoencoderKL
+ tokenizer_cls, tokenizer_id, tokenizer_subfolder = (
+ TokenizerWrapper,
+ "hf-internal-testing/tiny-random-cogview4",
+ "tokenizer",
+ )
+ text_encoder_cls, text_encoder_id, text_encoder_subfolder = (
+ GlmModel,
+ "hf-internal-testing/tiny-random-cogview4",
+ "text_encoder",
+ )
+
+ @property
+ def output_shape(self):
+ return (1, 32, 32, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ sizes = (4, 4)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "",
+ "num_inference_steps": 1,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ def test_simple_inference_save_pretrained(self):
+ """
+ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
+ """
+ for scheduler_cls in self.scheduler_classes:
+ components, _, _ = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pipe.save_pretrained(tmpdirname)
+
+ pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+ pipe_from_pretrained.to(torch_device)
+
+ images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
+
+ @unittest.skip("Not supported in CogView4.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in CogView4.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in CogView4.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in CogView4.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 3bc46d1e9b13..860aa6511689 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import gc
import os
import sys
@@ -19,15 +20,22 @@
import unittest
import numpy as np
+import pytest
import safetensors.torch
import torch
+from parameterized import parameterized
+from PIL import Image
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
-from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel
+from diffusers.utils import load_image, logging
from diffusers.utils.testing_utils import (
+ CaptureLogger,
floats_tensor,
is_peft_available,
+ nightly,
numpy_cosine_similarity_distance,
+ require_big_gpu_with_torch_cuda,
require_peft_backend,
require_torch_gpu,
slow,
@@ -155,6 +163,109 @@ def test_with_alpha_in_state_dict(self):
)
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
+ def test_lora_expansion_works_for_absent_keys(self):
+ components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+ # Modify the config to have a layer which won't be present in the second LoRA we will load.
+ modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
+ modified_denoiser_lora_config.target_modules.add("x_embedder")
+
+ pipe.transformer.add_adapter(modified_denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertFalse(
+ np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
+ "LoRA should lead to different results.",
+ )
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+ self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
+
+ # Modify the state dict to exclude "x_embedder" related LoRA params.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+ lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
+
+ pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
+ pipe.set_adapters(["one", "two"])
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+ images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ self.assertFalse(
+ np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+ "Different LoRAs should lead to different results.",
+ )
+ self.assertFalse(
+ np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+ "LoRA should lead to different results.",
+ )
+
+ def test_lora_expansion_works_for_extra_keys(self):
+ components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+ # Modify the config to have a layer which won't be present in the first LoRA we will load.
+ modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
+ modified_denoiser_lora_config.target_modules.add("x_embedder")
+
+ pipe.transformer.add_adapter(modified_denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertFalse(
+ np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
+ "LoRA should lead to different results.",
+ )
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+ self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ pipe.unload_lora_weights()
+ # Modify the state dict to exclude "x_embedder" related LoRA params.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+ lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
+ pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
+
+ # Load state dict with `x_embedder`.
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
+
+ pipe.set_adapters(["one", "two"])
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+ images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ self.assertFalse(
+ np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
+ "Different LoRAs should lead to different results.",
+ )
+ self.assertFalse(
+ np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
+ "LoRA should lead to different results.",
+ )
+
+ @unittest.skip("Not supported in Flux.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@@ -163,12 +274,545 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
def test_modify_padding_mode(self):
pass
+ @unittest.skip("Not supported in Flux.")
+ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+ pass
+
+
+class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = FluxControlPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler()
+ scheduler_kwargs = {}
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ transformer_kwargs = {
+ "patch_size": 1,
+ "in_channels": 8,
+ "out_channels": 4,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 16,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 32,
+ "pooled_projection_dim": 32,
+ "axes_dims_rope": [4, 4, 8],
+ }
+ transformer_cls = FluxTransformer2DModel
+ vae_kwargs = {
+ "sample_size": 32,
+ "in_channels": 3,
+ "out_channels": 3,
+ "block_out_channels": (4,),
+ "layers_per_block": 1,
+ "latent_channels": 1,
+ "norm_num_groups": 1,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ "shift_factor": 0.0609,
+ "scaling_factor": 1.5035,
+ }
+ has_two_text_encoders = True
+ tokenizer_cls, tokenizer_id = CLIPTokenizer, "peft-internal-testing/tiny-clip-text-2"
+ tokenizer_2_cls, tokenizer_2_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = CLIPTextModel, "peft-internal-testing/tiny-clip-text-2"
+ text_encoder_2_cls, text_encoder_2_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+ @property
+ def output_shape(self):
+ return (1, 8, 8, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 10
+ num_channels = 4
+ sizes = (32, 32)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),
+ "num_inference_steps": 4,
+ "guidance_scale": 0.0,
+ "height": 8,
+ "width": 8,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def test_with_norm_in_state_dict(self):
+ components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger.setLevel(logging.INFO)
+
+ original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]:
+ norm_state_dict = {}
+ for name, module in pipe.transformer.named_modules():
+ if norm_layer not in name or not hasattr(module, "weight") or module.weight is None:
+ continue
+ norm_state_dict[f"transformer.{name}.weight"] = torch.randn(
+ module.weight.shape, device=module.weight.device, dtype=module.weight.dtype
+ )
+
+ with CaptureLogger(logger) as cap_logger:
+ pipe.load_lora_weights(norm_state_dict)
+ lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(
+ "The provided state dict contains normalization layers in addition to LoRA layers"
+ in cap_logger.out
+ )
+ self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)
+
+ pipe.unload_lora_weights()
+ lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(pipe.transformer._transformer_norm_layers is None)
+ self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5))
+ self.assertFalse(
+ np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested"
+ )
+
+ with CaptureLogger(logger) as cap_logger:
+ for key in list(norm_state_dict.keys()):
+ norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key)
+ pipe.load_lora_weights(norm_state_dict)
+
+ self.assertTrue(
+ "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
+ )
+
+ def test_lora_parameter_expanded_shapes(self):
+ components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger.setLevel(logging.DEBUG)
+
+ # Change the transformer config to mimic a real use case.
+ num_channels_without_control = 4
+ transformer = FluxTransformer2DModel.from_config(
+ components["transformer"].config, in_channels=num_channels_without_control
+ ).to(torch_device)
+ self.assertTrue(
+ transformer.config.in_channels == num_channels_without_control,
+ f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+ )
+
+ original_transformer_state_dict = pipe.transformer.state_dict()
+ x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
+ incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
+ self.assertTrue(
+ "x_embedder.weight" in incompatible_keys.missing_keys,
+ "Could not find x_embedder.weight in the missing keys.",
+ )
+ transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
+ pipe.transformer = transformer
+
+ out_features, in_features = pipe.transformer.x_embedder.weight.shape
+ rank = 4
+
+ dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+ dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+ }
+ with CaptureLogger(logger) as cap_logger:
+ pipe.load_lora_weights(lora_state_dict, "adapter-1")
+
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+ lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+ self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+ self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+ self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+ # Testing opposite direction where the LoRA params are zero-padded.
+ components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
+ dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+ }
+ with CaptureLogger(logger) as cap_logger:
+ pipe.load_lora_weights(lora_state_dict, "adapter-1")
+
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+ lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+ self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+ self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+ self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
+
+ def test_normal_lora_with_expanded_lora_raises_error(self):
+ # Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
+ # load shape expanded LoRA (such as Control LoRA).
+ components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+ # Change the transformer config to mimic a real use case.
+ num_channels_without_control = 4
+ transformer = FluxTransformer2DModel.from_config(
+ components["transformer"].config, in_channels=num_channels_without_control
+ ).to(torch_device)
+ components["transformer"] = transformer
+
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger.setLevel(logging.DEBUG)
+
+ out_features, in_features = pipe.transformer.x_embedder.weight.shape
+ rank = 4
+
+ shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+ shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
+ }
+ with CaptureLogger(logger) as cap_logger:
+ pipe.load_lora_weights(lora_state_dict, "adapter-1")
+
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+ self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
+ self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+ self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+ self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
+ normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+ }
+
+ with CaptureLogger(logger) as cap_logger:
+ pipe.load_lora_weights(lora_state_dict, "adapter-2")
+
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+ self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
+ self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
+
+ lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
+
+ # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
+ # This should raise a runtime error on input shapes being incompatible.
+ components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+ # Change the transformer config to mimic a real use case.
+ num_channels_without_control = 4
+ transformer = FluxTransformer2DModel.from_config(
+ components["transformer"].config, in_channels=num_channels_without_control
+ ).to(torch_device)
+ components["transformer"] = transformer
+
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger.setLevel(logging.DEBUG)
+
+ out_features, in_features = pipe.transformer.x_embedder.weight.shape
+ rank = 4
+
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+ }
+ pipe.load_lora_weights(lora_state_dict, "adapter-1")
+
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+ self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
+ self.assertTrue(pipe.transformer.config.in_channels == in_features)
+
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
+ }
+
+ # We should check for input shapes being incompatible here. But because above mentioned issue is
+ # not a supported use case, and because of the PEFT renaming, we will currently have a shape
+ # mismatch error.
+ self.assertRaisesRegex(
+ RuntimeError,
+ "size mismatch for x_embedder.lora_A.adapter-2.weight",
+ pipe.load_lora_weights,
+ lora_state_dict,
+ "adapter-2",
+ )
+
+ def test_fuse_expanded_lora_with_regular_lora(self):
+ # This test checks if it works when a lora with expanded shapes (like control loras) but
+ # another lora with correct shapes is loaded. The opposite direction isn't supported and is
+ # tested with it.
+ components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+ # Change the transformer config to mimic a real use case.
+ num_channels_without_control = 4
+ transformer = FluxTransformer2DModel.from_config(
+ components["transformer"].config, in_channels=num_channels_without_control
+ ).to(torch_device)
+ components["transformer"] = transformer
+
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger.setLevel(logging.DEBUG)
+
+ out_features, in_features = pipe.transformer.x_embedder.weight.shape
+ rank = 4
+
+ shape_expander_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+ shape_expander_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
+ }
+ pipe.load_lora_weights(lora_state_dict, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
+ normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+ }
+
+ pipe.load_lora_weights(lora_state_dict, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+ lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
+ lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
+ self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
+ self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
+
+ pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
+ lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
+
+ def test_load_regular_lora(self):
+ # This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded
+ # into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
+ # transformers include Flux Fill, Flux Control, etc.
+ components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ out_features, in_features = pipe.transformer.x_embedder.weight.shape
+ rank = 4
+ in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA.
+ normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
+ normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
+ }
+
+ logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger.setLevel(logging.INFO)
+ with CaptureLogger(logger) as cap_logger:
+ pipe.load_lora_weights(lora_state_dict, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+ lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
+ self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
+ self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
+
+ def test_lora_unload_with_parameter_expanded_shapes(self):
+ components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+ logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger.setLevel(logging.DEBUG)
+
+ # Change the transformer config to mimic a real use case.
+ num_channels_without_control = 4
+ transformer = FluxTransformer2DModel.from_config(
+ components["transformer"].config, in_channels=num_channels_without_control
+ ).to(torch_device)
+ self.assertTrue(
+ transformer.config.in_channels == num_channels_without_control,
+ f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+ )
+
+ # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+ components["transformer"] = transformer
+ pipe = FluxPipeline(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ control_image = inputs.pop("control_image")
+ original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ control_pipe = self.pipeline_class(**components)
+ out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+ rank = 4
+
+ dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+ dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+ }
+ with CaptureLogger(logger) as cap_logger:
+ control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+ inputs["control_image"] = control_image
+ lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+ self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+ self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+ self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+ control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
+ self.assertTrue(
+ control_pipe.transformer.config.in_channels == num_channels_without_control,
+ f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
+ )
+ loaded_pipe = FluxPipeline.from_pipe(control_pipe)
+ self.assertTrue(
+ loaded_pipe.transformer.config.in_channels == num_channels_without_control,
+ f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
+ )
+ inputs.pop("control_image")
+ unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+ self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
+ self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
+ self.assertTrue(pipe.transformer.config.in_channels == in_features)
+
+ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
+ components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+ logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger.setLevel(logging.DEBUG)
+
+ # Change the transformer config to mimic a real use case.
+ num_channels_without_control = 4
+ transformer = FluxTransformer2DModel.from_config(
+ components["transformer"].config, in_channels=num_channels_without_control
+ ).to(torch_device)
+ self.assertTrue(
+ transformer.config.in_channels == num_channels_without_control,
+ f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+ )
+
+ # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+ components["transformer"] = transformer
+ pipe = FluxPipeline(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ control_image = inputs.pop("control_image")
+ original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ control_pipe = self.pipeline_class(**components)
+ out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+ rank = 4
+
+ dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+ dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+ }
+ with CaptureLogger(logger) as cap_logger:
+ control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+ inputs["control_image"] = control_image
+ lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+ self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+ self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+ self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+ control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
+ self.assertTrue(
+ control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
+ f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
+ )
+ no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+ self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
+ self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
+
+ @unittest.skip("Not supported in Flux.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Flux.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Flux.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Not supported in Flux.")
+ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+ pass
+
@slow
+@nightly
@require_torch_gpu
@require_peft_backend
-# @unittest.skip("We cannot run inference on this model with the current CI hardware")
-# TODO (DN6, sayakpaul): move these tests to a beefier GPU
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
class FluxLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on audace.
@@ -190,6 +834,7 @@ def setUp(self):
def tearDown(self):
super().tearDown()
+ del self.pipeline
gc.collect()
torch.cuda.empty_cache()
@@ -197,7 +842,10 @@ def test_flux_the_last_ben(self):
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
- self.pipeline.enable_model_cpu_offload()
+ # Instead of calling `enable_model_cpu_offload()`, we do a cuda placement here because the CI
+ # run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
+ # `enable_model_cpu_offload()`. We repeat this for the other tests, too.
+ self.pipeline = self.pipeline.to(torch_device)
prompt = "jon snow eating pizza with ketchup"
@@ -219,7 +867,7 @@ def test_flux_kohya(self):
self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
- self.pipeline.enable_model_cpu_offload()
+ self.pipeline = self.pipeline.to(torch_device)
prompt = "The cat with a brain slug earring"
out = self.pipeline(
@@ -241,7 +889,7 @@ def test_flux_kohya_with_text_encoder(self):
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
- self.pipeline.enable_model_cpu_offload()
+ self.pipeline = self.pipeline.to(torch_device)
prompt = "optimus is cleaning the house with broomstick"
out = self.pipeline(
@@ -263,7 +911,7 @@ def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
- self.pipeline.enable_model_cpu_offload()
+ self.pipeline = self.pipeline.to(torch_device)
prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
@@ -280,3 +928,128 @@ def test_flux_xlabs(self):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 1e-3
+
+ def test_flux_xlabs_load_lora_with_single_blocks(self):
+ self.pipeline.load_lora_weights(
+ "salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors"
+ )
+ self.pipeline.fuse_lora()
+ self.pipeline.unload_lora_weights()
+ self.pipeline.enable_model_cpu_offload()
+
+ prompt = "a wizard mouse playing chess"
+
+ out = self.pipeline(
+ prompt,
+ num_inference_steps=self.num_inference_steps,
+ guidance_scale=3.5,
+ output_type="np",
+ generator=torch.manual_seed(self.seed),
+ ).images
+ out_slice = out[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array(
+ [0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625]
+ )
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+ assert max_diff < 1e-3
+
+
+@nightly
+@require_torch_gpu
+@require_peft_backend
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
+class FluxControlLoRAIntegrationTests(unittest.TestCase):
+ num_inference_steps = 10
+ seed = 0
+ prompt = "A robot made of exotic candies and chocolates of different kinds."
+
+ def setUp(self):
+ super().setUp()
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ self.pipeline = FluxControlPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+ ).to("cuda")
+
+ def tearDown(self):
+ super().tearDown()
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
+ def test_lora(self, lora_ckpt_id):
+ self.pipeline.load_lora_weights(lora_ckpt_id)
+ self.pipeline.fuse_lora()
+ self.pipeline.unload_lora_weights()
+
+ if "Canny" in lora_ckpt_id:
+ control_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png"
+ )
+ else:
+ control_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
+ )
+
+ image = self.pipeline(
+ prompt=self.prompt,
+ control_image=control_image,
+ height=1024,
+ width=1024,
+ num_inference_steps=self.num_inference_steps,
+ guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0,
+ output_type="np",
+ generator=torch.manual_seed(self.seed),
+ ).images
+
+ out_slice = image[0, -3:, -3:, -1].flatten()
+ if "Canny" in lora_ckpt_id:
+ expected_slice = np.array([0.8438, 0.8438, 0.8438, 0.8438, 0.8438, 0.8398, 0.8438, 0.8438, 0.8516])
+ else:
+ expected_slice = np.array([0.8203, 0.8320, 0.8359, 0.8203, 0.8281, 0.8281, 0.8203, 0.8242, 0.8359])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+ assert max_diff < 1e-3
+
+ @parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
+ def test_lora_with_turbo(self, lora_ckpt_id):
+ self.pipeline.load_lora_weights(lora_ckpt_id)
+ self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
+ self.pipeline.fuse_lora()
+ self.pipeline.unload_lora_weights()
+
+ if "Canny" in lora_ckpt_id:
+ control_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png"
+ )
+ else:
+ control_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
+ )
+
+ image = self.pipeline(
+ prompt=self.prompt,
+ control_image=control_image,
+ height=1024,
+ width=1024,
+ num_inference_steps=self.num_inference_steps,
+ guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0,
+ output_type="np",
+ generator=torch.manual_seed(self.seed),
+ ).images
+
+ out_slice = image[0, -3:, -3:, -1].flatten()
+ if "Canny" in lora_ckpt_id:
+ expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484])
+ else:
+ expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+ assert max_diff < 1e-3
diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py
new file mode 100644
index 000000000000..d2015d8b0711
--- /dev/null
+++ b/tests/lora/test_lora_layers_hunyuanvideo.py
@@ -0,0 +1,257 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import sys
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanVideoPipeline,
+ HunyuanVideoTransformer3DModel,
+)
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ nightly,
+ numpy_cosine_similarity_distance,
+ require_big_gpu_with_torch_cuda,
+ require_peft_backend,
+ require_torch_gpu,
+ skip_mps,
+)
+
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+@skip_mps
+class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = HunyuanVideoPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_attention_heads": 2,
+ "attention_head_dim": 10,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "num_refiner_layers": 1,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "guidance_embeds": True,
+ "text_embed_dim": 16,
+ "pooled_projection_dim": 8,
+ "rope_axes_dim": (2, 4, 4),
+ }
+ transformer_cls = HunyuanVideoTransformer3DModel
+ vae_kwargs = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 4,
+ "down_block_types": (
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ),
+ "up_block_types": (
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ),
+ "block_out_channels": (8, 8, 8, 8),
+ "layers_per_block": 1,
+ "act_fn": "silu",
+ "norm_num_groups": 4,
+ "scaling_factor": 0.476986,
+ "spatial_compression_ratio": 8,
+ "temporal_compression_ratio": 4,
+ "mid_block_add_attention": True,
+ }
+ vae_cls = AutoencoderKLHunyuanVideo
+ has_two_text_encoders = True
+ tokenizer_cls, tokenizer_id, tokenizer_subfolder = (
+ LlamaTokenizerFast,
+ "hf-internal-testing/tiny-random-hunyuanvideo",
+ "tokenizer",
+ )
+ tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = (
+ CLIPTokenizer,
+ "hf-internal-testing/tiny-random-hunyuanvideo",
+ "tokenizer_2",
+ )
+ text_encoder_cls, text_encoder_id, text_encoder_subfolder = (
+ LlamaModel,
+ "hf-internal-testing/tiny-random-hunyuanvideo",
+ "text_encoder",
+ )
+ text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = (
+ CLIPTextModel,
+ "hf-internal-testing/tiny-random-hunyuanvideo",
+ "text_encoder_2",
+ )
+
+ @property
+ def output_shape(self):
+ return (1, 9, 32, 32, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ num_frames = 9
+ num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
+ sizes = (4, 4)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "",
+ "num_frames": num_frames,
+ "num_inference_steps": 1,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": sequence_length,
+ "prompt_template": {"template": "{}", "crop_start": 0},
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ # TODO(aryan): Fix the following test
+ @unittest.skip("This test fails with an error I haven't been able to debug yet.")
+ def test_simple_inference_save_pretrained(self):
+ pass
+
+ @unittest.skip("Not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in HunyuanVideo.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
+
+
+@nightly
+@require_torch_gpu
+@require_peft_backend
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
+class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
+ """internal note: The integration slices were obtained on DGX.
+
+ torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
+ assertions to pass.
+ """
+
+ num_inference_steps = 10
+ seed = 0
+
+ def setUp(self):
+ super().setUp()
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ model_id = "hunyuanvideo-community/HunyuanVideo"
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+ )
+ self.pipeline = HunyuanVideoPipeline.from_pretrained(
+ model_id, transformer=transformer, torch_dtype=torch.float16
+ ).to("cuda")
+
+ def tearDown(self):
+ super().tearDown()
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_original_format_cseti(self):
+ self.pipeline.load_lora_weights(
+ "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
+ )
+ self.pipeline.fuse_lora()
+ self.pipeline.unload_lora_weights()
+ self.pipeline.vae.enable_tiling()
+
+ prompt = "CSETIARCANE. A cat walks on the grass, realistic"
+
+ out = self.pipeline(
+ prompt=prompt,
+ height=320,
+ width=512,
+ num_frames=9,
+ num_inference_steps=self.num_inference_steps,
+ output_type="np",
+ generator=torch.manual_seed(self.seed),
+ ).frames[0]
+ out = out.flatten()
+ out_slice = np.concatenate((out[:8], out[-8:]))
+
+ # fmt: off
+ expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815])
+ # fmt: on
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+ assert max_diff < 1e-3
diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py
new file mode 100644
index 000000000000..0eccaa73ad42
--- /dev/null
+++ b/tests/lora/test_lora_layers_ltx_video.py
@@ -0,0 +1,147 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLLTXVideo,
+ FlowMatchEulerDiscreteScheduler,
+ LTXPipeline,
+ LTXVideoTransformer3DModel,
+)
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = LTXPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "in_channels": 8,
+ "out_channels": 8,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "num_attention_heads": 4,
+ "attention_head_dim": 8,
+ "cross_attention_dim": 32,
+ "num_layers": 1,
+ "caption_channels": 32,
+ }
+ transformer_cls = LTXVideoTransformer3DModel
+ vae_kwargs = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 8,
+ "block_out_channels": (8, 8, 8, 8),
+ "decoder_block_out_channels": (8, 8, 8, 8),
+ "layers_per_block": (1, 1, 1, 1, 1),
+ "decoder_layers_per_block": (1, 1, 1, 1, 1),
+ "spatio_temporal_scaling": (True, True, False, False),
+ "decoder_spatio_temporal_scaling": (True, True, False, False),
+ "decoder_inject_noise": (False, False, False, False, False),
+ "upsample_residual": (False, False, False, False),
+ "upsample_factor": (1, 1, 1, 1),
+ "timestep_conditioning": False,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "encoder_causal": True,
+ "decoder_causal": False,
+ }
+ vae_cls = AutoencoderKLLTXVideo
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+ text_encoder_target_modules = ["q", "k", "v", "o"]
+
+ @property
+ def output_shape(self):
+ return (1, 9, 32, 32, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 8
+ num_frames = 9
+ num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
+ latent_height = 8
+ latent_width = 8
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width))
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "dance monkey",
+ "num_frames": num_frames,
+ "num_inference_steps": 4,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ @unittest.skip("Not supported in LTXVideo.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in LTXVideo.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in LTXVideo.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py
new file mode 100644
index 000000000000..07b1cda2f79f
--- /dev/null
+++ b/tests/lora/test_lora_layers_lumina2.py
@@ -0,0 +1,172 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer, GemmaForCausalLM
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ Lumina2Text2ImgPipeline,
+ Lumina2Transformer2DModel,
+)
+from diffusers.utils.testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device
+
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
+
+
+@require_peft_backend
+class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = Lumina2Text2ImgPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "sample_size": 4,
+ "patch_size": 2,
+ "in_channels": 4,
+ "hidden_size": 8,
+ "num_layers": 2,
+ "num_attention_heads": 1,
+ "num_kv_heads": 1,
+ "multiple_of": 16,
+ "ffn_dim_multiplier": None,
+ "norm_eps": 1e-5,
+ "scaling_factor": 1.0,
+ "axes_dim_rope": [4, 2, 2],
+ "cap_feat_dim": 8,
+ }
+ transformer_cls = Lumina2Transformer2DModel
+ vae_kwargs = {
+ "sample_size": 32,
+ "in_channels": 3,
+ "out_channels": 3,
+ "block_out_channels": (4,),
+ "layers_per_block": 1,
+ "latent_channels": 4,
+ "norm_num_groups": 1,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ "shift_factor": 0.0609,
+ "scaling_factor": 1.5035,
+ }
+ vae_cls = AutoencoderKL
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma"
+ text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers"
+
+ @property
+ def output_shape(self):
+ return (1, 4, 4, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ sizes = (32, 32)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 32,
+ "width": 32,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ @unittest.skip("Not supported in Lumina2.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Lumina2.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Lumina2.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
+
+ @skip_mps
+ @pytest.mark.xfail(
+ condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
+ reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
+ strict=False,
+ )
+ def test_lora_fuse_nan(self):
+ for scheduler_cls in self.scheduler_classes:
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ )
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ # corrupt one LoRA weight with `inf` values
+ with torch.no_grad():
+ pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+
+ # with `safe_fusing=True` we should see an Error
+ with self.assertRaises(ValueError):
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+ # without we should not see an error, but every image will be black
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+ out = pipe(**inputs)[0]
+
+ self.assertTrue(np.isnan(out).all())
diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py
new file mode 100644
index 000000000000..671f1277f99f
--- /dev/null
+++ b/tests/lora/test_lora_layers_mochi.py
@@ -0,0 +1,142 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ require_peft_backend,
+ skip_mps,
+)
+
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+@skip_mps
+class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = MochiPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "patch_size": 2,
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "num_layers": 2,
+ "pooled_projection_dim": 16,
+ "in_channels": 12,
+ "out_channels": None,
+ "qk_norm": "rms_norm",
+ "text_embed_dim": 32,
+ "time_embed_dim": 4,
+ "activation_fn": "swiglu",
+ "max_sequence_length": 16,
+ }
+ transformer_cls = MochiTransformer3DModel
+ vae_kwargs = {
+ "latent_channels": 12,
+ "out_channels": 3,
+ "encoder_block_out_channels": (32, 32, 32, 32),
+ "decoder_block_out_channels": (32, 32, 32, 32),
+ "layers_per_block": (1, 1, 1, 1, 1),
+ }
+ vae_cls = AutoencoderKLMochi
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+ text_encoder_target_modules = ["q", "k", "v", "o"]
+
+ @property
+ def output_shape(self):
+ return (1, 7, 16, 16, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ num_frames = 7
+ num_latent_frames = 3
+ sizes = (2, 2)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "dance monkey",
+ "num_frames": num_frames,
+ "num_inference_steps": 4,
+ "guidance_scale": 6.0,
+ # Cannot reduce because convolution kernel becomes bigger than sample
+ "height": 16,
+ "width": 16,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ @unittest.skip("Not supported in Mochi.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Mochi.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Mochi.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Mochi.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
+
+ @unittest.skip("Not supported in CogVideoX.")
+ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+ pass
diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py
new file mode 100644
index 000000000000..78f71527cb7e
--- /dev/null
+++ b/tests/lora/test_lora_layers_sana.py
@@ -0,0 +1,138 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import unittest
+
+import torch
+from transformers import Gemma2Model, GemmaTokenizer
+
+from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = SanaPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ scheduler_kwargs = {}
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ transformer_kwargs = {
+ "patch_size": 1,
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_layers": 1,
+ "num_attention_heads": 2,
+ "attention_head_dim": 4,
+ "num_cross_attention_heads": 2,
+ "cross_attention_head_dim": 4,
+ "cross_attention_dim": 8,
+ "caption_channels": 8,
+ "sample_size": 32,
+ }
+ transformer_cls = SanaTransformer2DModel
+ vae_kwargs = {
+ "in_channels": 3,
+ "latent_channels": 4,
+ "attention_head_dim": 2,
+ "encoder_block_types": (
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ "decoder_block_types": (
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ "encoder_block_out_channels": (8, 8),
+ "decoder_block_out_channels": (8, 8),
+ "encoder_qkv_multiscales": ((), (5,)),
+ "decoder_qkv_multiscales": ((), (5,)),
+ "encoder_layers_per_block": (1, 1),
+ "decoder_layers_per_block": [1, 1],
+ "downsample_block_type": "conv",
+ "upsample_block_type": "interpolate",
+ "decoder_norm_types": "rms_norm",
+ "decoder_act_fns": "silu",
+ "scaling_factor": 0.41407,
+ }
+ vae_cls = AutoencoderDC
+ tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
+ text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"
+
+ @property
+ def output_shape(self):
+ return (1, 32, 32, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ sizes = (32, 32)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "",
+ "negative_prompt": "",
+ "num_inference_steps": 4,
+ "guidance_scale": 4.5,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ "complex_human_instruction": None,
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ @unittest.skip("Not supported in SANA.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Not supported in SANA.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in SANA.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in SANA.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in SANA.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in SANA.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in SANA.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in SANA.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py
index 50187e50a912..3eefa97663e6 100644
--- a/tests/lora/test_lora_layers_sd.py
+++ b/tests/lora/test_lora_layers_sd.py
@@ -33,10 +33,12 @@
)
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
load_image,
+ nightly,
numpy_cosine_similarity_distance,
require_peft_backend,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -100,7 +102,7 @@ def tearDown(self):
# Keeping this test here makes sense because it doesn't look any integration
# (value assertions on logits).
@slow
- @require_torch_gpu
+ @require_torch_accelerator
def test_integration_move_lora_cpu(self):
path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
lora_id = "takuma104/lora-test-text-encoder-lora-target"
@@ -157,7 +159,7 @@ def test_integration_move_lora_cpu(self):
self.assertTrue(m.weight.device != torch.device("cpu"))
@slow
- @require_torch_gpu
+ @require_torch_accelerator
def test_integration_move_lora_dora_cpu(self):
from peft import LoraConfig
@@ -207,18 +209,19 @@ def test_integration_move_lora_dora_cpu(self):
@slow
-@require_torch_gpu
+@nightly
+@require_torch_accelerator
@require_peft_backend
class LoraIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_integration_logits_with_scale(self):
path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
@@ -376,7 +379,7 @@ def test_a1111_with_model_cpu_offload(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
@@ -398,7 +401,7 @@ def test_a1111_with_sequential_cpu_offload(self):
generator = torch.Generator().manual_seed(0)
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
@@ -654,7 +657,7 @@ def test_sd_load_civitai_empty_network_alpha(self):
See: https://github.com/huggingface/diffusers/issues/5606
"""
pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- pipeline.enable_sequential_cpu_offload()
+ pipeline.enable_sequential_cpu_offload(device=torch_device)
civitai_path = hf_hub_download("ybelkada/test-ahi-civitai", "ahi_lora_weights.safetensors")
pipeline.load_lora_weights(civitai_path, adapter_name="ahri")
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index 78d4b786d21b..90aaa3bcfe78 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -17,6 +17,7 @@
import unittest
import numpy as np
+import pytest
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -29,17 +30,17 @@
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
- is_peft_available,
+ backend_empty_cache,
+ is_flaky,
+ nightly,
numpy_cosine_similarity_distance,
+ require_big_gpu_with_torch_cuda,
require_peft_backend,
- require_torch_gpu,
+ require_torch_accelerator,
torch_device,
)
-if is_peft_available():
- pass
-
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402
@@ -93,7 +94,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def output_shape(self):
return (1, 32, 32, 3)
- @require_torch_gpu
+ @require_torch_accelerator
def test_sd3_lora(self):
"""
Test loading the loras that are saved with the diffusers and peft formats.
@@ -129,22 +130,29 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
def test_modify_padding_mode(self):
pass
+ @is_flaky
+ def test_multiple_wrong_adapter_name_raises_error(self):
+ super().test_multiple_wrong_adapter_name_raises_error()
+
-@require_torch_gpu
+@nightly
+@require_torch_accelerator
@require_peft_backend
-class LoraSD3IntegrationTests(unittest.TestCase):
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
+class SD3LoraIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, seed=0):
init_image = load_image(
@@ -166,47 +174,16 @@ def get_inputs(self, device, seed=0):
def test_sd3_img2img_lora(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
- pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
- pipe.enable_sequential_cpu_offload()
+ pipe.load_lora_weights("zwloong/sd3-lora-training-rank16-v2")
+ pipe.fuse_lora()
+ pipe.unload_lora_weights()
+ pipe = pipe.to(torch_device)
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
- image_slice = image[0, :10, :10]
- expected_slice = np.array(
- [
- 0.47827148,
- 0.5,
- 0.71972656,
- 0.3955078,
- 0.4194336,
- 0.69628906,
- 0.37036133,
- 0.40820312,
- 0.6923828,
- 0.36450195,
- 0.40429688,
- 0.6904297,
- 0.35595703,
- 0.39257812,
- 0.68652344,
- 0.35498047,
- 0.3984375,
- 0.68310547,
- 0.34716797,
- 0.3996582,
- 0.6855469,
- 0.3388672,
- 0.3959961,
- 0.6816406,
- 0.34033203,
- 0.40429688,
- 0.6845703,
- 0.34228516,
- 0.4086914,
- 0.6870117,
- ]
- )
+ image_slice = image[0, -3:, -3:]
+ expected_slice = np.array([0.5649, 0.5405, 0.5488, 0.5688, 0.5449, 0.5513, 0.5337, 0.5107, 0.5059])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py
index 94a44ed8f9ec..76d6dc48602b 100644
--- a/tests/lora/test_lora_layers_sdxl.py
+++ b/tests/lora/test_lora_layers_sdxl.py
@@ -37,6 +37,7 @@
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
CaptureLogger,
+ is_flaky,
load_image,
nightly,
numpy_cosine_similarity_distance,
@@ -111,8 +112,13 @@ def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
+ @is_flaky
+ def test_multiple_wrong_adapter_name_raises_error(self):
+ super().test_multiple_wrong_adapter_name_raises_error()
+
@slow
+@nightly
@require_torch_gpu
@require_peft_backend
class LoraSDXLIntegrationTests(unittest.TestCase):
diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py
new file mode 100644
index 000000000000..c2498fa68c3d
--- /dev/null
+++ b/tests/lora/test_lora_layers_wan.py
@@ -0,0 +1,143 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ FlowMatchEulerDiscreteScheduler,
+ WanPipeline,
+ WanTransformer3DModel,
+)
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ require_peft_backend,
+ skip_mps,
+)
+
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+@skip_mps
+class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = WanPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "in_channels": 16,
+ "out_channels": 16,
+ "text_dim": 32,
+ "freq_dim": 256,
+ "ffn_dim": 32,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 32,
+ }
+ transformer_cls = WanTransformer3DModel
+ vae_kwargs = {
+ "base_dim": 3,
+ "z_dim": 16,
+ "dim_mult": [1, 1, 1, 1],
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True, True],
+ }
+ vae_cls = AutoencoderKLWan
+ has_two_text_encoders = True
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+ text_encoder_target_modules = ["q", "k", "v", "o"]
+
+ @property
+ def output_shape(self):
+ return (1, 9, 32, 32, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ num_frames = 9
+ num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
+ sizes = (4, 4)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "",
+ "num_frames": num_frames,
+ "num_inference_steps": 1,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ @unittest.skip("Not supported in Wan.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Wan.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Wan.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index e7fc840fcaa5..8cdb43c9d085 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -14,11 +14,13 @@
# limitations under the License.
import inspect
import os
+import re
import tempfile
import unittest
from itertools import product
import numpy as np
+import pytest
import torch
from diffusers import (
@@ -32,6 +34,7 @@
from diffusers.utils.testing_utils import (
CaptureLogger,
floats_tensor,
+ is_torch_version,
require_peft_backend,
require_peft_version_greater,
require_transformers_version_greater,
@@ -74,6 +77,9 @@ def initialize_dummy_state_dict(state_dict):
return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()}
+POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]
+
+
@require_peft_backend
class PeftLoraLoaderMixinTests:
pipeline_class = None
@@ -84,12 +90,12 @@ class PeftLoraLoaderMixinTests:
has_two_text_encoders = False
has_three_text_encoders = False
- text_encoder_cls, text_encoder_id = None, None
- text_encoder_2_cls, text_encoder_2_id = None, None
- text_encoder_3_cls, text_encoder_3_id = None, None
- tokenizer_cls, tokenizer_id = None, None
- tokenizer_2_cls, tokenizer_2_id = None, None
- tokenizer_3_cls, tokenizer_3_id = None, None
+ text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, ""
+ text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, ""
+ text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, ""
+ tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
+ tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
+ tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
unet_kwargs = None
transformer_cls = None
@@ -119,16 +125,26 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False):
torch.manual_seed(0)
vae = self.vae_cls(**self.vae_kwargs)
- text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
- tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)
+ text_encoder = self.text_encoder_cls.from_pretrained(
+ self.text_encoder_id, subfolder=self.text_encoder_subfolder
+ )
+ tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder)
if self.text_encoder_2_cls is not None:
- text_encoder_2 = self.text_encoder_2_cls.from_pretrained(self.text_encoder_2_id)
- tokenizer_2 = self.tokenizer_2_cls.from_pretrained(self.tokenizer_2_id)
+ text_encoder_2 = self.text_encoder_2_cls.from_pretrained(
+ self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder
+ )
+ tokenizer_2 = self.tokenizer_2_cls.from_pretrained(
+ self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder
+ )
if self.text_encoder_3_cls is not None:
- text_encoder_3 = self.text_encoder_3_cls.from_pretrained(self.text_encoder_3_id)
- tokenizer_3 = self.tokenizer_3_cls.from_pretrained(self.tokenizer_3_id)
+ text_encoder_3 = self.text_encoder_3_cls.from_pretrained(
+ self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder
+ )
+ tokenizer_3 = self.tokenizer_3_cls.from_pretrained(
+ self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder
+ )
text_lora_config = LoraConfig(
r=rank,
@@ -427,7 +443,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
# TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
- for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
+ for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
if possible_attention_kwargs in call_signature_keys:
attention_kwargs_name = possible_attention_kwargs
break
@@ -788,7 +804,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
and makes sure it works as expected
"""
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
- for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
+ for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
if possible_attention_kwargs in call_signature_keys:
attention_kwargs_name = possible_attention_kwargs
break
@@ -1119,6 +1135,43 @@ def test_wrong_adapter_name_raises_error(self):
pipe.set_adapters("adapter-1")
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ def test_multiple_wrong_adapter_name_raises_error(self):
+ scheduler_cls = self.scheduler_classes[0]
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+
+ scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
+ logger = logging.get_logger("diffusers.loaders.lora_base")
+ logger.setLevel(30)
+ with CaptureLogger(logger) as cap_logger:
+ pipe.set_adapters("adapter-1", adapter_weights=scale_with_wrong_components)
+
+ wrong_components = sorted(set(scale_with_wrong_components.keys()))
+ msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. "
+ self.assertTrue(msg in str(cap_logger.out))
+
+ # test this works.
+ pipe.set_adapters("adapter-1")
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
def test_simple_inference_with_text_denoiser_block_scale(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
@@ -1510,6 +1563,11 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
)
@skip_mps
+ @pytest.mark.xfail(
+ condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
+ reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
+ strict=False,
+ )
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -1535,7 +1593,18 @@ def test_lora_fuse_nan(self):
"adapter-1"
].weight += float("inf")
else:
- pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+ named_modules = [name for name, _ in pipe.transformer.named_modules()]
+ tower_name = (
+ "transformer_blocks"
+ if any(name == "transformer_blocks" for name in named_modules)
+ else "blocks"
+ )
+ transformer_tower = getattr(pipe.transformer, tower_name)
+ has_attn1 = any("attn1" in name for name in named_modules)
+ if has_attn1:
+ transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
+ else:
+ transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
@@ -1543,7 +1612,7 @@ def test_lora_fuse_nan(self):
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
- out = pipe("test", num_inference_steps=2, output_type="np")[0]
+ out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
@@ -1784,11 +1853,7 @@ def test_missing_keys_warning(self):
missing_key = [k for k in state_dict if "lora_A" in k][0]
del state_dict[missing_key]
- logger = (
- logging.get_logger("diffusers.loaders.unet")
- if self.unet_kwargs is not None
- else logging.get_logger("diffusers.loaders.lora_pipeline")
- )
+ logger = logging.get_logger("diffusers.loaders.peft")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(state_dict)
@@ -1823,11 +1888,7 @@ def test_unexpected_keys_warning(self):
unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
- logger = (
- logging.get_logger("diffusers.loaders.unet")
- if self.unet_kwargs is not None
- else logging.get_logger("diffusers.loaders.lora_pipeline")
- )
+ logger = logging.get_logger("diffusers.loaders.peft")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(state_dict)
@@ -1886,3 +1947,391 @@ def set_pad_mode(network, mode="circular"):
_, _, inputs = self.get_dummy_inputs()
_ = pipe(**inputs)[0]
+
+ def test_logs_info_when_no_lora_keys_found(self):
+ scheduler_cls = self.scheduler_classes[0]
+ # Skip text encoder check for now as that is handled with `transformers`.
+ components, _, _ = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
+ logger = logging.get_logger("diffusers.loaders.peft")
+ logger.setLevel(logging.WARNING)
+
+ with CaptureLogger(logger) as cap_logger:
+ pipe.load_lora_weights(no_op_state_dict)
+ out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
+ self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
+ self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
+
+ # test only for text encoder
+ for lora_module in self.pipeline_class._lora_loadable_modules:
+ if "text_encoder" in lora_module:
+ text_encoder = getattr(pipe, lora_module)
+ if lora_module == "text_encoder":
+ prefix = "text_encoder"
+ elif lora_module == "text_encoder_2":
+ prefix = "text_encoder_2"
+
+ logger = logging.get_logger("diffusers.loaders.lora_base")
+ logger.setLevel(logging.WARNING)
+
+ with CaptureLogger(logger) as cap_logger:
+ self.pipeline_class.load_lora_into_text_encoder(
+ no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix
+ )
+
+ self.assertTrue(
+ cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
+ )
+
+ def test_set_adapters_match_attention_kwargs(self):
+ """Test to check if outputs after `set_adapters()` and attention kwargs match."""
+ call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
+ for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
+ if possible_attention_kwargs in call_signature_keys:
+ attention_kwargs_name = possible_attention_kwargs
+ break
+ assert attention_kwargs_name is not None
+
+ for scheduler_cls in self.scheduler_classes:
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ )
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+
+ lora_scale = 0.5
+ attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+ output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+ self.assertFalse(
+ np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
+
+ pipe.set_adapters("default", lora_scale)
+ output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+ "Lora + scale should match the output of `set_adapters()`.",
+ )
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+ )
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+ for module_name, module in modules_to_save.items():
+ self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+
+ output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+ self.assertTrue(
+ not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results as attention_kwargs.",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results as set_adapters().",
+ )
+
+ @require_peft_version_greater("0.13.2")
+ def test_lora_B_bias(self):
+ # Currently, this test is only relevant for Flux Control LoRA as we are not
+ # aware of any other LoRA checkpoint that has its `lora_B` biases trained.
+ components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ # keep track of the bias values of the base layers to perform checks later.
+ bias_values = {}
+ denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+ for name, module in denoiser.named_modules():
+ if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
+ if module.bias is not None:
+ bias_values[name] = module.bias.data.clone()
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger.setLevel(logging.INFO)
+
+ original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ denoiser_lora_config.lora_bias = False
+ if self.unet_kwargs is not None:
+ pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+ else:
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+ lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.delete_adapters("adapter-1")
+
+ denoiser_lora_config.lora_bias = True
+ if self.unet_kwargs is not None:
+ pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+ else:
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+ lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
+ self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
+ self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
+
+ def test_correct_lora_configs_with_different_ranks(self):
+ components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ if self.unet_kwargs is not None:
+ pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+ else:
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+
+ lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ if self.unet_kwargs is not None:
+ pipe.unet.delete_adapters("adapter-1")
+ else:
+ pipe.transformer.delete_adapters("adapter-1")
+
+ denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+ for name, _ in denoiser.named_modules():
+ if "to_k" in name and "attn" in name and "lora" not in name:
+ module_name_to_rank_update = name.replace(".base_layer.", ".")
+ break
+
+ # change the rank_pattern
+ updated_rank = denoiser_lora_config.r * 2
+ denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
+
+ if self.unet_kwargs is not None:
+ pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+ updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern
+ else:
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+ updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
+
+ self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
+
+ lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+ self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+ if self.unet_kwargs is not None:
+ pipe.unet.delete_adapters("adapter-1")
+ else:
+ pipe.transformer.delete_adapters("adapter-1")
+
+ # similarly change the alpha_pattern
+ updated_alpha = denoiser_lora_config.lora_alpha * 2
+ denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
+ if self.unet_kwargs is not None:
+ pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(
+ pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
+ )
+ else:
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(
+ pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
+ )
+
+ lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
+ self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+ def test_layerwise_casting_inference_denoiser(self):
+ from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+
+ def check_linear_dtype(module, storage_dtype, compute_dtype):
+ patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+ if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
+ patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
+ for name, submodule in module.named_modules():
+ if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+ continue
+ dtype_to_check = storage_dtype
+ if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
+ dtype_to_check = compute_dtype
+ if getattr(submodule, "weight", None) is not None:
+ self.assertEqual(submodule.weight.dtype, dtype_to_check)
+ if getattr(submodule, "bias", None) is not None:
+ self.assertEqual(submodule.bias.dtype, dtype_to_check)
+
+ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device, dtype=compute_dtype)
+ pipe.set_progress_bar_config(disable=None)
+
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ )
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+
+ if storage_dtype is not None:
+ denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+ check_linear_dtype(denoiser, storage_dtype, compute_dtype)
+
+ return pipe
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ pipe_fp32 = initialize_pipeline(storage_dtype=None)
+ pipe_fp32(**inputs, generator=torch.manual_seed(0))[0]
+
+ pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
+ pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]
+
+ pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+ pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
+
+ @require_peft_version_greater("0.14.0")
+ def test_layerwise_casting_peft_input_autocast_denoiser(self):
+ r"""
+ A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This
+ is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise
+ cast hooks on the PEFT layers (relevant logic in `models.modeling_utils.ModelMixin.enable_layerwise_casting`).
+ In this test, we enable the layerwise casting on the PEFT layers as well. If run with PEFT version <= 0.14.0,
+ this test will fail with the following error:
+
+ ```
+ RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float
+ ```
+
+ See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
+ """
+
+ from diffusers.hooks.layerwise_casting import (
+ _PEFT_AUTOCAST_DISABLE_HOOK,
+ DEFAULT_SKIP_MODULES_PATTERN,
+ SUPPORTED_PYTORCH_LAYERS,
+ apply_layerwise_casting,
+ )
+
+ storage_dtype = torch.float8_e4m3fn
+ compute_dtype = torch.float32
+
+ def check_module(denoiser):
+ # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
+ for name, module in denoiser.named_modules():
+ if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+ continue
+ dtype_to_check = storage_dtype
+ if any(re.search(pattern, name) for pattern in patterns_to_check):
+ dtype_to_check = compute_dtype
+ if getattr(module, "weight", None) is not None:
+ self.assertEqual(module.weight.dtype, dtype_to_check)
+ if getattr(module, "bias", None) is not None:
+ self.assertEqual(module.bias.dtype, dtype_to_check)
+ if isinstance(module, BaseTunerLayer):
+ self.assertTrue(getattr(module, "_diffusers_hook", None) is not None)
+ self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
+
+ # 1. Test forward with add_adapter
+ components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device, dtype=compute_dtype)
+ pipe.set_progress_bar_config(disable=None)
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+ if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
+ patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)
+
+ apply_layerwise_casting(
+ denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
+ )
+ check_module(denoiser)
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ # 2. Test forward with load_lora_weights
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+ )
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device, dtype=compute_dtype)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ apply_layerwise_casting(
+ denoiser,
+ storage_dtype=storage_dtype,
+ compute_dtype=compute_dtype,
+ skip_modules_pattern=patterns_to_check,
+ )
+ check_module(denoiser)
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ pipe(**inputs, generator=torch.manual_seed(0))[0]
diff --git a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
new file mode 100644
index 000000000000..7efb390287ab
--- /dev/null
+++ b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
@@ -0,0 +1,261 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from diffusers import AsymmetricAutoencoderKL
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ floats_tensor,
+ load_hf_numpy,
+ require_torch_accelerator,
+ require_torch_gpu,
+ skip_mps,
+ slow,
+ torch_all_close,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AsymmetricAutoencoderKL
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_asym_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
+ block_out_channels = block_out_channels or [2, 4]
+ norm_num_groups = norm_num_groups or 2
+ init_dict = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+ "down_block_out_channels": block_out_channels,
+ "layers_per_down_block": 1,
+ "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+ "up_block_out_channels": block_out_channels,
+ "layers_per_up_block": 1,
+ "act_fn": "silu",
+ "latent_channels": 4,
+ "norm_num_groups": norm_num_groups,
+ "sample_size": 32,
+ "scaling_factor": 0.18215,
+ }
+ return init_dict
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 3
+ sizes = (32, 32)
+
+ image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+ mask = torch.ones((batch_size, 1) + sizes).to(torch_device)
+
+ return {"sample": image, "mask": mask}
+
+ @property
+ def input_shape(self):
+ return (3, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (3, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_asym_autoencoder_kl_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ @unittest.skip("Unsupported test.")
+ def test_forward_with_norm_groups(self):
+ pass
+
+
+@slow
+class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
+ def get_file_format(self, seed, shape):
+ return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
+
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
+ dtype = torch.float16 if fp16 else torch.float32
+ image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
+ return image
+
+ def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
+ revision = "main"
+ torch_dtype = torch.float32
+
+ model = AsymmetricAutoencoderKL.from_pretrained(
+ model_id,
+ torch_dtype=torch_dtype,
+ revision=revision,
+ )
+ model.to(torch_device).eval()
+
+ return model
+
+ def get_generator(self, seed=0):
+ generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
+ if torch_device != "mps":
+ return torch.Generator(device=generator_device).manual_seed(seed)
+ return torch.manual_seed(seed)
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [
+ 33,
+ [-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
+ [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
+ ],
+ [
+ 47,
+ [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
+ [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
+ ],
+ # fmt: on
+ ]
+ )
+ def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
+ model = self.get_sd_vae_model()
+ image = self.get_sd_image(seed)
+ generator = self.get_generator(seed)
+
+ with torch.no_grad():
+ sample = model(image, generator=generator, sample_posterior=True).sample
+
+ assert sample.shape == image.shape
+
+ output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
+ expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
+
+ assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [
+ 33,
+ [-0.0340, 0.2870, 0.1698, -0.0105, -0.3448, 0.3529, -0.1321, 0.1097],
+ [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078],
+ ],
+ [
+ 47,
+ [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
+ [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
+ ],
+ # fmt: on
+ ]
+ )
+ def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
+ model = self.get_sd_vae_model()
+ image = self.get_sd_image(seed)
+
+ with torch.no_grad():
+ sample = model(image).sample
+
+ assert sample.shape == image.shape
+
+ output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
+ expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
+
+ assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [13, [-0.0521, -0.2939, 0.1540, -0.1855, -0.5936, -0.3138, -0.4579, -0.2275]],
+ [37, [-0.1820, -0.4345, -0.0455, -0.2923, -0.8035, -0.5089, -0.4795, -0.3106]],
+ # fmt: on
+ ]
+ )
+ @require_torch_accelerator
+ @skip_mps
+ def test_stable_diffusion_decode(self, seed, expected_slice):
+ model = self.get_sd_vae_model()
+ encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
+
+ with torch.no_grad():
+ sample = model.decode(encoding).sample
+
+ assert list(sample.shape) == [3, 3, 512, 512]
+
+ output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
+ expected_output_slice = torch.tensor(expected_slice)
+
+ assert torch_all_close(output_slice, expected_output_slice, atol=2e-3)
+
+ @parameterized.expand([(13,), (16,), (37,)])
+ @require_torch_gpu
+ @unittest.skipIf(
+ not is_xformers_available(),
+ reason="xformers is not required when using PyTorch 2.0.",
+ )
+ def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
+ model = self.get_sd_vae_model()
+ encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
+
+ with torch.no_grad():
+ sample = model.decode(encoding).sample
+
+ model.enable_xformers_memory_efficient_attention()
+ with torch.no_grad():
+ sample_2 = model.decode(encoding).sample
+
+ assert list(sample.shape) == [3, 3, 512, 512]
+
+ assert torch_all_close(sample, sample_2, atol=5e-2)
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
+ [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
+ # fmt: on
+ ]
+ )
+ def test_stable_diffusion_encode_sample(self, seed, expected_slice):
+ model = self.get_sd_vae_model()
+ image = self.get_sd_image(seed)
+ generator = self.get_generator(seed)
+
+ with torch.no_grad():
+ dist = model.encode(image).latent_dist
+ sample = dist.sample(generator=generator)
+
+ assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
+
+ output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
+ expected_output_slice = torch.tensor(expected_slice)
+
+ tolerance = 3e-3 if torch_device != "mps" else 1e-2
+ assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
diff --git a/tests/models/autoencoders/test_models_autoencoder_dc.py b/tests/models/autoencoders/test_models_autoencoder_dc.py
new file mode 100644
index 000000000000..5f21593d8e04
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_dc.py
@@ -0,0 +1,87 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderDC
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderDC
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_dc_config(self):
+ return {
+ "in_channels": 3,
+ "latent_channels": 4,
+ "attention_head_dim": 2,
+ "encoder_block_types": (
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ "decoder_block_types": (
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ "encoder_block_out_channels": (8, 8),
+ "decoder_block_out_channels": (8, 8),
+ "encoder_qkv_multiscales": ((), (5,)),
+ "decoder_qkv_multiscales": ((), (5,)),
+ "encoder_layers_per_block": (1, 1),
+ "decoder_layers_per_block": [1, 1],
+ "downsample_block_type": "conv",
+ "upsample_block_type": "interpolate",
+ "decoder_norm_types": "rms_norm",
+ "decoder_act_fns": "silu",
+ "scaling_factor": 0.41407,
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 3
+ sizes = (32, 32)
+
+ image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (3, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_dc_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ @unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
+ def test_forward_with_norm_groups(self):
+ pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
new file mode 100644
index 000000000000..00d4b8ed2b5f
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
@@ -0,0 +1,210 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import AutoencoderKLHunyuanVideo
+from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLHunyuanVideo
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_hunyuan_video_config(self):
+ return {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 4,
+ "down_block_types": (
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ),
+ "up_block_types": (
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ),
+ "block_out_channels": (8, 8, 8, 8),
+ "layers_per_block": 1,
+ "act_fn": "silu",
+ "norm_num_groups": 4,
+ "scaling_factor": 0.476986,
+ "spatial_compression_ratio": 8,
+ "temporal_compression_ratio": 4,
+ "mid_block_add_attention": True,
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (16, 16)
+
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 9, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 9, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_hunyuan_video_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_enable_disable_tiling(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_tiling()
+ output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE tiling should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_tiling()
+ output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_tiling.detach().cpu().numpy().all(),
+ output_without_tiling_2.detach().cpu().numpy().all(),
+ "Without tiling outputs should match with the outputs when tiling is manually disabled.",
+ )
+
+ def test_enable_disable_slicing(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_slicing()
+ output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE slicing should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_slicing()
+ output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_slicing.detach().cpu().numpy().all(),
+ output_without_slicing_2.detach().cpu().numpy().all(),
+ "Without slicing outputs should match with the outputs when slicing is manually disabled.",
+ )
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {
+ "HunyuanVideoDecoder3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoEncoder3D",
+ "HunyuanVideoMidBlock3D",
+ "HunyuanVideoUpBlock3D",
+ }
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ # We need to overwrite this test because the base test does not account length of down_block_types
+ def test_forward_with_norm_groups(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["norm_num_groups"] = 16
+ init_dict["block_out_channels"] = (16, 16, 16, 16)
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.to_tuple()[0]
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ @unittest.skip("Unsupported test.")
+ def test_outputs_equivalence(self):
+ pass
+
+ def test_prepare_causal_attention_mask(self):
+ def prepare_causal_attention_mask_orig(
+ num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
+ ) -> torch.Tensor:
+ seq_len = num_frames * height_width
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
+ for i in range(seq_len):
+ i_frame = i // height_width
+ mask[i, : (i_frame + 1) * height_width] = 0
+ if batch_size is not None:
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
+ return mask
+
+ # test with some odd shapes
+ original_mask = prepare_causal_attention_mask_orig(
+ num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
+ )
+ new_mask = prepare_causal_attention_mask(
+ num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
+ )
+ self.assertTrue(
+ torch.allclose(original_mask, new_mask),
+ "Causal attention mask should be the same",
+ )
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py
new file mode 100644
index 000000000000..9126594000f6
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_kl.py
@@ -0,0 +1,468 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from diffusers import AutoencoderKL
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ floats_tensor,
+ load_hf_numpy,
+ require_torch_accelerator,
+ require_torch_accelerator_with_fp16,
+ require_torch_gpu,
+ skip_mps,
+ slow,
+ torch_all_close,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKL
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
+ block_out_channels = block_out_channels or [2, 4]
+ norm_num_groups = norm_num_groups or 2
+ init_dict = {
+ "block_out_channels": block_out_channels,
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+ "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+ "latent_channels": 4,
+ "norm_num_groups": norm_num_groups,
+ }
+ return init_dict
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 3
+ sizes = (32, 32)
+
+ image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (3, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_enable_disable_tiling(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_tiling()
+ output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE tiling should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_tiling()
+ output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_tiling.detach().cpu().numpy().all(),
+ output_without_tiling_2.detach().cpu().numpy().all(),
+ "Without tiling outputs should match with the outputs when tiling is manually disabled.",
+ )
+
+ def test_enable_disable_slicing(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_slicing()
+ output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE slicing should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_slicing()
+ output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_slicing.detach().cpu().numpy().all(),
+ output_without_slicing_2.detach().cpu().numpy().all(),
+ "Without slicing outputs should match with the outputs when slicing is manually disabled.",
+ )
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ def test_from_pretrained_hub(self):
+ model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
+ self.assertIsNotNone(model)
+ self.assertEqual(len(loading_info["missing_keys"]), 0)
+
+ model.to(torch_device)
+ image = model(**self.dummy_input)
+
+ assert image is not None, "Make sure output is not None"
+
+ def test_output_pretrained(self):
+ model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
+ model = model.to(torch_device)
+ model.eval()
+
+ # Keep generator on CPU for non-CUDA devices to compare outputs with CPU result tensors
+ generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
+ if torch_device != "mps":
+ generator = torch.Generator(device=generator_device).manual_seed(0)
+ else:
+ generator = torch.manual_seed(0)
+
+ image = torch.randn(
+ 1,
+ model.config.in_channels,
+ model.config.sample_size,
+ model.config.sample_size,
+ generator=torch.manual_seed(0),
+ )
+ image = image.to(torch_device)
+ with torch.no_grad():
+ output = model(image, sample_posterior=True, generator=generator).sample
+
+ output_slice = output[0, -1, -3:, -3:].flatten().cpu()
+
+ # Since the VAE Gaussian prior's generator is seeded on the appropriate device,
+ # the expected output slices are not the same for CPU and GPU.
+ if torch_device == "mps":
+ expected_output_slice = torch.tensor(
+ [
+ -4.0078e-01,
+ -3.8323e-04,
+ -1.2681e-01,
+ -1.1462e-01,
+ 2.0095e-01,
+ 1.0893e-01,
+ -8.8247e-02,
+ -3.0361e-01,
+ -9.8644e-03,
+ ]
+ )
+ elif generator_device == "cpu":
+ expected_output_slice = torch.tensor(
+ [
+ -0.1352,
+ 0.0878,
+ 0.0419,
+ -0.0818,
+ -0.1069,
+ 0.0688,
+ -0.1458,
+ -0.4446,
+ -0.0026,
+ ]
+ )
+ else:
+ expected_output_slice = torch.tensor(
+ [
+ -0.2421,
+ 0.4642,
+ 0.2507,
+ -0.0438,
+ 0.0682,
+ 0.3160,
+ -0.2018,
+ -0.0727,
+ 0.2485,
+ ]
+ )
+
+ self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
+
+
+@slow
+class AutoencoderKLIntegrationTests(unittest.TestCase):
+ def get_file_format(self, seed, shape):
+ return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
+
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
+ dtype = torch.float16 if fp16 else torch.float32
+ image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
+ return image
+
+ def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
+ revision = "fp16" if fp16 else None
+ torch_dtype = torch.float16 if fp16 else torch.float32
+
+ model = AutoencoderKL.from_pretrained(
+ model_id,
+ subfolder="vae",
+ torch_dtype=torch_dtype,
+ revision=revision,
+ )
+ model.to(torch_device)
+
+ return model
+
+ def get_generator(self, seed=0):
+ generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
+ if torch_device != "mps":
+ return torch.Generator(device=generator_device).manual_seed(seed)
+ return torch.manual_seed(seed)
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [
+ 33,
+ [-0.1556, 0.9848, -0.0410, -0.0642, -0.2685, 0.8381, -0.2004, -0.0700],
+ [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824],
+ ],
+ [
+ 47,
+ [-0.2376, 0.1200, 0.1337, -0.4830, -0.2504, -0.0759, -0.0486, -0.4077],
+ [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131],
+ ],
+ # fmt: on
+ ]
+ )
+ def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
+ model = self.get_sd_vae_model()
+ image = self.get_sd_image(seed)
+ generator = self.get_generator(seed)
+
+ with torch.no_grad():
+ sample = model(image, generator=generator, sample_posterior=True).sample
+
+ assert sample.shape == image.shape
+
+ output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
+ expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
+
+ assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [33, [-0.0513, 0.0289, 1.3799, 0.2166, -0.2573, -0.0871, 0.5103, -0.0999]],
+ [47, [-0.4128, -0.1320, -0.3704, 0.1965, -0.4116, -0.2332, -0.3340, 0.2247]],
+ # fmt: on
+ ]
+ )
+ @require_torch_accelerator_with_fp16
+ def test_stable_diffusion_fp16(self, seed, expected_slice):
+ model = self.get_sd_vae_model(fp16=True)
+ image = self.get_sd_image(seed, fp16=True)
+ generator = self.get_generator(seed)
+
+ with torch.no_grad():
+ sample = model(image, generator=generator, sample_posterior=True).sample
+
+ assert sample.shape == image.shape
+
+ output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
+ expected_output_slice = torch.tensor(expected_slice)
+
+ assert torch_all_close(output_slice, expected_output_slice, atol=1e-2)
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [
+ 33,
+ [-0.1609, 0.9866, -0.0487, -0.0777, -0.2716, 0.8368, -0.2055, -0.0814],
+ [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824],
+ ],
+ [
+ 47,
+ [-0.2377, 0.1147, 0.1333, -0.4841, -0.2506, -0.0805, -0.0491, -0.4085],
+ [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131],
+ ],
+ # fmt: on
+ ]
+ )
+ def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
+ model = self.get_sd_vae_model()
+ image = self.get_sd_image(seed)
+
+ with torch.no_grad():
+ sample = model(image).sample
+
+ assert sample.shape == image.shape
+
+ output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
+ expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
+
+ assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [13, [-0.2051, -0.1803, -0.2311, -0.2114, -0.3292, -0.3574, -0.2953, -0.3323]],
+ [37, [-0.2632, -0.2625, -0.2199, -0.2741, -0.4539, -0.4990, -0.3720, -0.4925]],
+ # fmt: on
+ ]
+ )
+ @require_torch_accelerator
+ @skip_mps
+ def test_stable_diffusion_decode(self, seed, expected_slice):
+ model = self.get_sd_vae_model()
+ encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
+
+ with torch.no_grad():
+ sample = model.decode(encoding).sample
+
+ assert list(sample.shape) == [3, 3, 512, 512]
+
+ output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
+ expected_output_slice = torch.tensor(expected_slice)
+
+ assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [27, [-0.0369, 0.0207, -0.0776, -0.0682, -0.1747, -0.1930, -0.1465, -0.2039]],
+ [16, [-0.1628, -0.2134, -0.2747, -0.2642, -0.3774, -0.4404, -0.3687, -0.4277]],
+ # fmt: on
+ ]
+ )
+ @require_torch_accelerator_with_fp16
+ def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
+ model = self.get_sd_vae_model(fp16=True)
+ encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
+
+ with torch.no_grad():
+ sample = model.decode(encoding).sample
+
+ assert list(sample.shape) == [3, 3, 512, 512]
+
+ output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
+ expected_output_slice = torch.tensor(expected_slice)
+
+ assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
+
+ @parameterized.expand([(13,), (16,), (27,)])
+ @require_torch_gpu
+ @unittest.skipIf(
+ not is_xformers_available(),
+ reason="xformers is not required when using PyTorch 2.0.",
+ )
+ def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
+ model = self.get_sd_vae_model(fp16=True)
+ encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
+
+ with torch.no_grad():
+ sample = model.decode(encoding).sample
+
+ model.enable_xformers_memory_efficient_attention()
+ with torch.no_grad():
+ sample_2 = model.decode(encoding).sample
+
+ assert list(sample.shape) == [3, 3, 512, 512]
+
+ assert torch_all_close(sample, sample_2, atol=1e-1)
+
+ @parameterized.expand([(13,), (16,), (37,)])
+ @require_torch_gpu
+ @unittest.skipIf(
+ not is_xformers_available(),
+ reason="xformers is not required when using PyTorch 2.0.",
+ )
+ def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
+ model = self.get_sd_vae_model()
+ encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
+
+ with torch.no_grad():
+ sample = model.decode(encoding).sample
+
+ model.enable_xformers_memory_efficient_attention()
+ with torch.no_grad():
+ sample_2 = model.decode(encoding).sample
+
+ assert list(sample.shape) == [3, 3, 512, 512]
+
+ assert torch_all_close(sample, sample_2, atol=1e-2)
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
+ [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
+ # fmt: on
+ ]
+ )
+ def test_stable_diffusion_encode_sample(self, seed, expected_slice):
+ model = self.get_sd_vae_model()
+ image = self.get_sd_image(seed)
+ generator = self.get_generator(seed)
+
+ with torch.no_grad():
+ dist = model.encode(image).latent_dist
+ sample = dist.sample(generator=generator)
+
+ assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
+
+ output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
+ expected_output_slice = torch.tensor(expected_slice)
+
+ tolerance = 3e-3 if torch_device != "mps" else 1e-2
+ assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
new file mode 100644
index 000000000000..7336bb3d3e97
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
@@ -0,0 +1,179 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import AutoencoderKLCogVideoX
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLCogVideoX
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_cogvideox_config(self):
+ return {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": (
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ ),
+ "up_block_types": (
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ ),
+ "block_out_channels": (8, 8, 8, 8),
+ "latent_channels": 4,
+ "layers_per_block": 1,
+ "norm_num_groups": 2,
+ "temporal_compression_ratio": 4,
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_frames = 8
+ num_channels = 3
+ sizes = (16, 16)
+
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 8, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 8, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_cogvideox_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_enable_disable_tiling(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_tiling()
+ output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE tiling should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_tiling()
+ output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_tiling.detach().cpu().numpy().all(),
+ output_without_tiling_2.detach().cpu().numpy().all(),
+ "Without tiling outputs should match with the outputs when tiling is manually disabled.",
+ )
+
+ def test_enable_disable_slicing(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_slicing()
+ output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE slicing should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_slicing()
+ output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_slicing.detach().cpu().numpy().all(),
+ output_without_slicing_2.detach().cpu().numpy().all(),
+ "Without slicing outputs should match with the outputs when slicing is manually disabled.",
+ )
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {
+ "CogVideoXDownBlock3D",
+ "CogVideoXDecoder3D",
+ "CogVideoXEncoder3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXMidBlock3D",
+ }
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ def test_forward_with_norm_groups(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["norm_num_groups"] = 16
+ init_dict["block_out_channels"] = (16, 32, 32, 32)
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.to_tuple()[0]
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ @unittest.skip("Unsupported test.")
+ def test_outputs_equivalence(self):
+ pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
new file mode 100644
index 000000000000..cf80ff50443e
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
@@ -0,0 +1,73 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLTemporalDecoder
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLTemporalDecoder
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ @property
+ def dummy_input(self):
+ batch_size = 3
+ num_channels = 3
+ sizes = (32, 32)
+
+ image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+ num_frames = 3
+
+ return {"sample": image, "num_frames": num_frames}
+
+ @property
+ def input_shape(self):
+ return (3, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (3, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "block_out_channels": [32, 64],
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ "latent_channels": 4,
+ "layers_per_block": 2,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ @unittest.skip("Test unsupported.")
+ def test_forward_with_norm_groups(self):
+ pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
new file mode 100644
index 000000000000..66d170b28eee
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
@@ -0,0 +1,200 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import AutoencoderKLLTXVideo
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLLTXVideo
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_ltx_video_config(self):
+ return {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 8,
+ "block_out_channels": (8, 8, 8, 8),
+ "decoder_block_out_channels": (8, 8, 8, 8),
+ "layers_per_block": (1, 1, 1, 1, 1),
+ "decoder_layers_per_block": (1, 1, 1, 1, 1),
+ "spatio_temporal_scaling": (True, True, False, False),
+ "decoder_spatio_temporal_scaling": (True, True, False, False),
+ "decoder_inject_noise": (False, False, False, False, False),
+ "upsample_residual": (False, False, False, False),
+ "upsample_factor": (1, 1, 1, 1),
+ "timestep_conditioning": False,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "encoder_causal": True,
+ "decoder_causal": False,
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (16, 16)
+
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 9, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 9, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_ltx_video_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {
+ "LTXVideoEncoder3d",
+ "LTXVideoDecoder3d",
+ "LTXVideoDownBlock3D",
+ "LTXVideoMidBlock3d",
+ "LTXVideoUpBlock3d",
+ }
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ @unittest.skip("Unsupported test.")
+ def test_outputs_equivalence(self):
+ pass
+
+ @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
+ def test_forward_with_norm_groups(self):
+ pass
+
+
+class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLLTXVideo
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_ltx_video_config(self):
+ return {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 8,
+ "block_out_channels": (8, 8, 8, 8),
+ "decoder_block_out_channels": (16, 32, 64),
+ "layers_per_block": (1, 1, 1, 1),
+ "decoder_layers_per_block": (1, 1, 1, 1),
+ "spatio_temporal_scaling": (True, True, True, False),
+ "decoder_spatio_temporal_scaling": (True, True, True),
+ "decoder_inject_noise": (True, True, True, False),
+ "upsample_residual": (True, True, True),
+ "upsample_factor": (2, 2, 2),
+ "timestep_conditioning": True,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "encoder_causal": True,
+ "decoder_causal": False,
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (16, 16)
+
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+ timestep = torch.tensor([0.05] * batch_size, device=torch_device)
+
+ return {"sample": image, "temb": timestep}
+
+ @property
+ def input_shape(self):
+ return (3, 9, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 9, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_ltx_video_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {
+ "LTXVideoEncoder3d",
+ "LTXVideoDecoder3d",
+ "LTXVideoDownBlock3D",
+ "LTXVideoMidBlock3d",
+ "LTXVideoUpBlock3d",
+ }
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ @unittest.skip("Unsupported test.")
+ def test_outputs_equivalence(self):
+ pass
+
+ @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
+ def test_forward_with_norm_groups(self):
+ pass
+
+ def test_enable_disable_tiling(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_tiling()
+ output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE tiling should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_tiling()
+ output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_tiling.detach().cpu().numpy().all(),
+ output_without_tiling_2.detach().cpu().numpy().all(),
+ "Without tiling outputs should match with the outputs when tiling is manually disabled.",
+ )
diff --git a/tests/models/autoencoders/test_models_autoencoder_magvit.py b/tests/models/autoencoders/test_models_autoencoder_magvit.py
new file mode 100644
index 000000000000..ee7e5bbdd485
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_magvit.py
@@ -0,0 +1,90 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLMagvit
+from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLMagvit
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_magvit_config(self):
+ return {
+ "in_channels": 3,
+ "latent_channels": 4,
+ "out_channels": 3,
+ "block_out_channels": [8, 8, 8, 8],
+ "down_block_types": [
+ "SpatialDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ ],
+ "up_block_types": [
+ "SpatialUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ ],
+ "layers_per_block": 1,
+ "norm_num_groups": 8,
+ "spatial_group_norm": True,
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ height = 16
+ width = 16
+
+ image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 9, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 9, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_magvit_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"EasyAnimateEncoder", "EasyAnimateDecoder"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ @unittest.skip("Not quite sure why this test fails. Revisit later.")
+ def test_effective_gradient_checkpointing(self):
+ pass
+
+ @unittest.skip("Unsupported test.")
+ def test_forward_with_norm_groups(self):
+ pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_oobleck.py b/tests/models/autoencoders/test_models_autoencoder_oobleck.py
new file mode 100644
index 000000000000..2adea6bda439
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_oobleck.py
@@ -0,0 +1,252 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+from datasets import load_dataset
+from parameterized import parameterized
+
+from diffusers import AutoencoderOobleck
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ floats_tensor,
+ slow,
+ torch_all_close,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderOobleck
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_oobleck_config(self, block_out_channels=None):
+ init_dict = {
+ "encoder_hidden_size": 12,
+ "decoder_channels": 12,
+ "decoder_input_channels": 6,
+ "audio_channels": 2,
+ "downsampling_ratios": [2, 4],
+ "channel_multiples": [1, 2],
+ }
+ return init_dict
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 2
+ seq_len = 24
+
+ waveform = floats_tensor((batch_size, num_channels, seq_len)).to(torch_device)
+
+ return {"sample": waveform, "sample_posterior": False}
+
+ @property
+ def input_shape(self):
+ return (2, 24)
+
+ @property
+ def output_shape(self):
+ return (2, 24)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_oobleck_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_enable_disable_slicing(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_slicing()
+ output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE slicing should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_slicing()
+ output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_slicing.detach().cpu().numpy().all(),
+ output_without_slicing_2.detach().cpu().numpy().all(),
+ "Without slicing outputs should match with the outputs when slicing is manually disabled.",
+ )
+
+ @unittest.skip("Test unsupported.")
+ def test_forward_with_norm_groups(self):
+ pass
+
+ @unittest.skip("No attention module used in this model")
+ def test_set_attn_processor_for_determinism(self):
+ return
+
+ @unittest.skip(
+ "Test not supported because of 'weight_norm_fwd_first_dim_kernel' not implemented for 'Float8_e4m3fn'"
+ )
+ def test_layerwise_casting_training(self):
+ return super().test_layerwise_casting_training()
+
+ @unittest.skip(
+ "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
+ "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"
+ "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
+ "2. Unskip this test."
+ )
+ def test_layerwise_casting_inference(self):
+ pass
+
+ @unittest.skip(
+ "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
+ "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"
+ "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
+ "2. Unskip this test."
+ )
+ def test_layerwise_casting_memory(self):
+ pass
+
+
+@slow
+class AutoencoderOobleckIntegrationTests(unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def _load_datasamples(self, num_samples):
+ ds = load_dataset(
+ "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True
+ )
+ # automatic decoding with librispeech
+ speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
+
+ return torch.nn.utils.rnn.pad_sequence(
+ [torch.from_numpy(x["array"]) for x in speech_samples], batch_first=True
+ )
+
+ def get_audio(self, audio_sample_size=2097152, fp16=False):
+ dtype = torch.float16 if fp16 else torch.float32
+ audio = self._load_datasamples(2).to(torch_device).to(dtype)
+
+ # pad / crop to audio_sample_size
+ audio = torch.nn.functional.pad(audio[:, :audio_sample_size], pad=(0, audio_sample_size - audio.shape[-1]))
+
+ # todo channel
+ audio = audio.unsqueeze(1).repeat(1, 2, 1).to(torch_device)
+
+ return audio
+
+ def get_oobleck_vae_model(self, model_id="stabilityai/stable-audio-open-1.0", fp16=False):
+ torch_dtype = torch.float16 if fp16 else torch.float32
+
+ model = AutoencoderOobleck.from_pretrained(
+ model_id,
+ subfolder="vae",
+ torch_dtype=torch_dtype,
+ )
+ model.to(torch_device)
+
+ return model
+
+ def get_generator(self, seed=0):
+ generator_device = "cpu" if not torch_device.startswith(torch_device) else torch_device
+ if torch_device != "mps":
+ return torch.Generator(device=generator_device).manual_seed(seed)
+ return torch.manual_seed(seed)
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
+ [44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
+ # fmt: on
+ ]
+ )
+ def test_stable_diffusion(self, seed, expected_slice, expected_mean_absolute_diff):
+ model = self.get_oobleck_vae_model()
+ audio = self.get_audio()
+ generator = self.get_generator(seed)
+
+ with torch.no_grad():
+ sample = model(audio, generator=generator, sample_posterior=True).sample
+
+ assert sample.shape == audio.shape
+ assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
+
+ output_slice = sample[-1, 1, 5:10].cpu()
+ expected_output_slice = torch.tensor(expected_slice)
+
+ assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)
+
+ def test_stable_diffusion_mode(self):
+ model = self.get_oobleck_vae_model()
+ audio = self.get_audio()
+
+ with torch.no_grad():
+ sample = model(audio, sample_posterior=False).sample
+
+ assert sample.shape == audio.shape
+
+ @parameterized.expand(
+ [
+ # fmt: off
+ [33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
+ [44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
+ # fmt: on
+ ]
+ )
+ def test_stable_diffusion_encode_decode(self, seed, expected_slice, expected_mean_absolute_diff):
+ model = self.get_oobleck_vae_model()
+ audio = self.get_audio()
+ generator = self.get_generator(seed)
+
+ with torch.no_grad():
+ x = audio
+ posterior = model.encode(x).latent_dist
+ z = posterior.sample(generator=generator)
+ sample = model.decode(z).sample
+
+ # (batch_size, latent_dim, sequence_length)
+ assert posterior.mean.shape == (audio.shape[0], model.config.decoder_input_channels, 1024)
+
+ assert sample.shape == audio.shape
+ assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
+
+ output_slice = sample[-1, 1, 5:10].cpu()
+ expected_output_slice = torch.tensor(expected_slice)
+
+ assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)
diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py
new file mode 100644
index 000000000000..bfbfb7ab8593
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py
@@ -0,0 +1,267 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import gc
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from diffusers import AutoencoderTiny
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ floats_tensor,
+ load_hf_numpy,
+ slow,
+ torch_all_close,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderTiny
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_tiny_config(self, block_out_channels=None):
+ block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
+ init_dict = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "encoder_block_out_channels": block_out_channels,
+ "decoder_block_out_channels": block_out_channels,
+ "num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
+ "num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
+ }
+ return init_dict
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 3
+ sizes = (32, 32)
+
+ image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (3, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_tiny_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ @unittest.skip("Model doesn't yet support smaller resolution.")
+ def test_enable_disable_tiling(self):
+ pass
+
+ def test_enable_disable_slicing(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+
+ torch.manual_seed(0)
+ output_without_slicing = model(**inputs_dict)[0]
+
+ torch.manual_seed(0)
+ model.enable_slicing()
+ output_with_slicing = model(**inputs_dict)[0]
+
+ self.assertLess(
+ (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE slicing should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_slicing()
+ output_without_slicing_2 = model(**inputs_dict)[0]
+
+ self.assertEqual(
+ output_without_slicing.detach().cpu().numpy().all(),
+ output_without_slicing_2.detach().cpu().numpy().all(),
+ "Without slicing outputs should match with the outputs when slicing is manually disabled.",
+ )
+
+ @unittest.skip("Test not supported.")
+ def test_outputs_equivalence(self):
+ pass
+
+ @unittest.skip("Test not supported.")
+ def test_forward_with_norm_groups(self):
+ pass
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"DecoderTiny", "EncoderTiny"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ def test_effective_gradient_checkpointing(self):
+ if not self.model_class._supports_gradient_checkpointing:
+ return # Skip test if model does not support gradient checkpointing
+
+ # enable deterministic behavior for gradient checkpointing
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ inputs_dict_copy = copy.deepcopy(inputs_dict)
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ assert not model.is_gradient_checkpointing and model.training
+
+ out = model(**inputs_dict).sample
+ # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+ # we won't calculate the loss and rather backprop on out.sum()
+ model.zero_grad()
+
+ labels = torch.randn_like(out)
+ loss = (out - labels).mean()
+ loss.backward()
+
+ # re-instantiate the model now enabling gradient checkpointing
+ torch.manual_seed(0)
+ model_2 = self.model_class(**init_dict)
+ # clone model
+ model_2.load_state_dict(model.state_dict())
+ model_2.to(torch_device)
+ model_2.enable_gradient_checkpointing()
+
+ assert model_2.is_gradient_checkpointing and model_2.training
+
+ out_2 = model_2(**inputs_dict_copy).sample
+ # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+ # we won't calculate the loss and rather backprop on out.sum()
+ model_2.zero_grad()
+ loss_2 = (out_2 - labels).mean()
+ loss_2.backward()
+
+ # compare the output and parameters gradients
+ self.assertTrue((loss - loss_2).abs() < 1e-3)
+ named_params = dict(model.named_parameters())
+ named_params_2 = dict(model_2.named_parameters())
+
+ for name, param in named_params.items():
+ if "encoder.layers" in name:
+ continue
+ self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=3e-2))
+
+ @unittest.skip(
+ "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n"
+ "1. Change the forward pass to be dtype agnostic.\n"
+ "2. Unskip this test."
+ )
+ def test_layerwise_casting_inference(self):
+ pass
+
+ @unittest.skip(
+ "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n"
+ "1. Change the forward pass to be dtype agnostic.\n"
+ "2. Unskip this test."
+ )
+ def test_layerwise_casting_memory(self):
+ pass
+
+
+@slow
+class AutoencoderTinyIntegrationTests(unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_file_format(self, seed, shape):
+ return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
+
+ def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
+ dtype = torch.float16 if fp16 else torch.float32
+ image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
+ return image
+
+ def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
+ torch_dtype = torch.float16 if fp16 else torch.float32
+
+ model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
+ model.to(torch_device).eval()
+ return model
+
+ @parameterized.expand(
+ [
+ [(1, 4, 73, 97), (1, 3, 584, 776)],
+ [(1, 4, 97, 73), (1, 3, 776, 584)],
+ [(1, 4, 49, 65), (1, 3, 392, 520)],
+ [(1, 4, 65, 49), (1, 3, 520, 392)],
+ [(1, 4, 49, 49), (1, 3, 392, 392)],
+ ]
+ )
+ def test_tae_tiling(self, in_shape, out_shape):
+ model = self.get_sd_vae_model()
+ model.enable_tiling()
+ with torch.no_grad():
+ zeros = torch.zeros(in_shape).to(torch_device)
+ dec = model.decode(zeros).sample
+ assert dec.shape == out_shape
+
+ def test_stable_diffusion(self):
+ model = self.get_sd_vae_model()
+ image = self.get_sd_image(seed=33)
+
+ with torch.no_grad():
+ sample = model(image).sample
+
+ assert sample.shape == image.shape
+
+ output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
+ expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
+
+ assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
+
+ @parameterized.expand([(True,), (False,)])
+ def test_tae_roundtrip(self, enable_tiling):
+ # load the autoencoder
+ model = self.get_sd_vae_model()
+ if enable_tiling:
+ model.enable_tiling()
+
+ # make a black image with a white square in the middle,
+ # which is large enough to split across multiple tiles
+ image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
+ image[..., 256:768, 256:768] = 1.0
+
+ # round-trip the image through the autoencoder
+ with torch.no_grad():
+ sample = model(image).sample
+
+ # the autoencoder reconstruction should match original image, sorta
+ def downscale(x):
+ return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)
+
+ assert torch_all_close(downscale(sample), downscale(image), atol=0.125)
diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py
new file mode 100644
index 000000000000..ffc474039889
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_wan.py
@@ -0,0 +1,79 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLWan
+from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLWan
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_wan_config(self):
+ return {
+ "base_dim": 3,
+ "z_dim": 16,
+ "dim_mult": [1, 1, 1, 1],
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True, True],
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (16, 16)
+
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 9, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 9, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_wan_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ @unittest.skip("Gradient checkpointing has not been implemented yet")
+ def test_gradient_checkpointing_is_applied(self):
+ pass
+
+ @unittest.skip("Test not supported")
+ def test_forward_with_norm_groups(self):
+ pass
+
+ @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
+ def test_layerwise_casting_inference(self):
+ pass
+
+ @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
+ def test_layerwise_casting_training(self):
+ pass
diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py
new file mode 100644
index 000000000000..77977a78d83b
--- /dev/null
+++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py
@@ -0,0 +1,300 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+
+from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ load_image,
+ slow,
+ torch_all_close,
+ torch_device,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
+ model_class = ConsistencyDecoderVAE
+ main_input_name = "sample"
+ base_precision = 1e-2
+ forward_requires_fresh_args = True
+
+ def get_consistency_vae_config(self, block_out_channels=None, norm_num_groups=None):
+ block_out_channels = block_out_channels or [2, 4]
+ norm_num_groups = norm_num_groups or 2
+ return {
+ "encoder_block_out_channels": block_out_channels,
+ "encoder_in_channels": 3,
+ "encoder_out_channels": 4,
+ "encoder_down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+ "decoder_add_attention": False,
+ "decoder_block_out_channels": block_out_channels,
+ "decoder_down_block_types": ["ResnetDownsampleBlock2D"] * len(block_out_channels),
+ "decoder_downsample_padding": 1,
+ "decoder_in_channels": 7,
+ "decoder_layers_per_block": 1,
+ "decoder_norm_eps": 1e-05,
+ "decoder_norm_num_groups": norm_num_groups,
+ "encoder_norm_num_groups": norm_num_groups,
+ "decoder_num_train_timesteps": 1024,
+ "decoder_out_channels": 6,
+ "decoder_resnet_time_scale_shift": "scale_shift",
+ "decoder_time_embedding_type": "learned",
+ "decoder_up_block_types": ["ResnetUpsampleBlock2D"] * len(block_out_channels),
+ "scaling_factor": 1,
+ "latent_channels": 4,
+ }
+
+ def inputs_dict(self, seed=None):
+ if seed is None:
+ generator = torch.Generator("cpu").manual_seed(0)
+ else:
+ generator = torch.Generator("cpu").manual_seed(seed)
+ image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
+
+ return {"sample": image, "generator": generator}
+
+ @property
+ def input_shape(self):
+ return (3, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (3, 32, 32)
+
+ @property
+ def init_dict(self):
+ return self.get_consistency_vae_config()
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return self.init_dict, self.inputs_dict()
+
+ def test_enable_disable_tiling(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+ _ = inputs_dict.pop("generator")
+
+ torch.manual_seed(0)
+ output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_tiling()
+ output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE tiling should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_tiling()
+ output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_tiling.detach().cpu().numpy().all(),
+ output_without_tiling_2.detach().cpu().numpy().all(),
+ "Without tiling outputs should match with the outputs when tiling is manually disabled.",
+ )
+
+ def test_enable_disable_slicing(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ inputs_dict.update({"return_dict": False})
+ _ = inputs_dict.pop("generator")
+
+ torch.manual_seed(0)
+ output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(0)
+ model.enable_slicing()
+ output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertLess(
+ (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
+ 0.5,
+ "VAE slicing should not affect the inference results",
+ )
+
+ torch.manual_seed(0)
+ model.disable_slicing()
+ output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+ self.assertEqual(
+ output_without_slicing.detach().cpu().numpy().all(),
+ output_without_slicing_2.detach().cpu().numpy().all(),
+ "Without slicing outputs should match with the outputs when slicing is manually disabled.",
+ )
+
+
+@slow
+class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
+ def setUp(self):
+ # clean up the VRAM before each test
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ @torch.no_grad()
+ def test_encode_decode(self):
+ vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
+ vae.to(torch_device)
+
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/img2img/sketch-mountains-input.jpg"
+ ).resize((256, 256))
+ image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :].to(
+ torch_device
+ )
+
+ latent = vae.encode(image).latent_dist.mean
+
+ sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
+
+ actual_output = sample[0, :2, :2, :2].flatten().cpu()
+ expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
+
+ assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+ def test_sd(self):
+ vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
+ pipe = StableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, safety_checker=None
+ )
+ pipe.to(torch_device)
+
+ out = pipe(
+ "horse",
+ num_inference_steps=2,
+ output_type="pt",
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).images[0]
+
+ actual_output = out[:2, :2, :2].flatten().cpu()
+ expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
+
+ assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+ def test_encode_decode_f16(self):
+ vae = ConsistencyDecoderVAE.from_pretrained(
+ "openai/consistency-decoder", torch_dtype=torch.float16
+ ) # TODO - update
+ vae.to(torch_device)
+
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/img2img/sketch-mountains-input.jpg"
+ ).resize((256, 256))
+ image = (
+ torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
+ .half()
+ .to(torch_device)
+ )
+
+ latent = vae.encode(image).latent_dist.mean
+
+ sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
+
+ actual_output = sample[0, :2, :2, :2].flatten().cpu()
+ expected_output = torch.tensor(
+ [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471],
+ dtype=torch.float16,
+ )
+
+ assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+ def test_sd_f16(self):
+ vae = ConsistencyDecoderVAE.from_pretrained(
+ "openai/consistency-decoder", torch_dtype=torch.float16
+ ) # TODO - update
+ pipe = StableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ vae=vae,
+ safety_checker=None,
+ )
+ pipe.to(torch_device)
+
+ out = pipe(
+ "horse",
+ num_inference_steps=2,
+ output_type="pt",
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).images[0]
+
+ actual_output = out[:2, :2, :2].flatten().cpu()
+ expected_output = torch.tensor(
+ [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035],
+ dtype=torch.float16,
+ )
+
+ assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+ def test_vae_tiling(self):
+ vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
+ pipe = StableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, safety_checker=None, torch_dtype=torch.float16
+ )
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ out_1 = pipe(
+ "horse",
+ num_inference_steps=2,
+ output_type="pt",
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).images[0]
+
+ # make sure tiled vae decode yields the same result
+ pipe.enable_vae_tiling()
+ out_2 = pipe(
+ "horse",
+ num_inference_steps=2,
+ output_type="pt",
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).images[0]
+
+ assert torch_all_close(out_1, out_2, atol=5e-3)
+
+ # test that tiled decode works with various shapes
+ shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
+ with torch.no_grad():
+ for shape in shapes:
+ image = torch.zeros(shape, device=torch_device, dtype=pipe.vae.dtype)
+ pipe.vae.decode(image)
diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py
deleted file mode 100644
index 0188f9121ae0..000000000000
--- a/tests/models/autoencoders/test_models_vae.py
+++ /dev/null
@@ -1,1270 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from datasets import load_dataset
-from parameterized import parameterized
-
-from diffusers import (
- AsymmetricAutoencoderKL,
- AutoencoderKL,
- AutoencoderKLTemporalDecoder,
- AutoencoderOobleck,
- AutoencoderTiny,
- ConsistencyDecoderVAE,
- StableDiffusionPipeline,
-)
-from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.loading_utils import load_image
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- load_hf_numpy,
- require_torch_accelerator,
- require_torch_accelerator_with_fp16,
- require_torch_accelerator_with_training,
- require_torch_gpu,
- skip_mps,
- slow,
- torch_all_close,
- torch_device,
-)
-from diffusers.utils.torch_utils import randn_tensor
-
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
-
-
-enable_full_determinism()
-
-
-def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
- block_out_channels = block_out_channels or [2, 4]
- norm_num_groups = norm_num_groups or 2
- init_dict = {
- "block_out_channels": block_out_channels,
- "in_channels": 3,
- "out_channels": 3,
- "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
- "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
- "latent_channels": 4,
- "norm_num_groups": norm_num_groups,
- }
- return init_dict
-
-
-def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
- block_out_channels = block_out_channels or [2, 4]
- norm_num_groups = norm_num_groups or 2
- init_dict = {
- "in_channels": 3,
- "out_channels": 3,
- "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
- "down_block_out_channels": block_out_channels,
- "layers_per_down_block": 1,
- "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
- "up_block_out_channels": block_out_channels,
- "layers_per_up_block": 1,
- "act_fn": "silu",
- "latent_channels": 4,
- "norm_num_groups": norm_num_groups,
- "sample_size": 32,
- "scaling_factor": 0.18215,
- }
- return init_dict
-
-
-def get_autoencoder_tiny_config(block_out_channels=None):
- block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
- init_dict = {
- "in_channels": 3,
- "out_channels": 3,
- "encoder_block_out_channels": block_out_channels,
- "decoder_block_out_channels": block_out_channels,
- "num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
- "num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
- }
- return init_dict
-
-
-def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
- block_out_channels = block_out_channels or [2, 4]
- norm_num_groups = norm_num_groups or 2
- return {
- "encoder_block_out_channels": block_out_channels,
- "encoder_in_channels": 3,
- "encoder_out_channels": 4,
- "encoder_down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
- "decoder_add_attention": False,
- "decoder_block_out_channels": block_out_channels,
- "decoder_down_block_types": ["ResnetDownsampleBlock2D"] * len(block_out_channels),
- "decoder_downsample_padding": 1,
- "decoder_in_channels": 7,
- "decoder_layers_per_block": 1,
- "decoder_norm_eps": 1e-05,
- "decoder_norm_num_groups": norm_num_groups,
- "encoder_norm_num_groups": norm_num_groups,
- "decoder_num_train_timesteps": 1024,
- "decoder_out_channels": 6,
- "decoder_resnet_time_scale_shift": "scale_shift",
- "decoder_time_embedding_type": "learned",
- "decoder_up_block_types": ["ResnetUpsampleBlock2D"] * len(block_out_channels),
- "scaling_factor": 1,
- "latent_channels": 4,
- }
-
-
-def get_autoencoder_oobleck_config(block_out_channels=None):
- init_dict = {
- "encoder_hidden_size": 12,
- "decoder_channels": 12,
- "decoder_input_channels": 6,
- "audio_channels": 2,
- "downsampling_ratios": [2, 4],
- "channel_multiples": [1, 2],
- }
- return init_dict
-
-
-class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
- model_class = AutoencoderKL
- main_input_name = "sample"
- base_precision = 1e-2
-
- @property
- def dummy_input(self):
- batch_size = 4
- num_channels = 3
- sizes = (32, 32)
-
- image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-
- return {"sample": image}
-
- @property
- def input_shape(self):
- return (3, 32, 32)
-
- @property
- def output_shape(self):
- return (3, 32, 32)
-
- def prepare_init_args_and_inputs_for_common(self):
- init_dict = get_autoencoder_kl_config()
- inputs_dict = self.dummy_input
- return init_dict, inputs_dict
-
- def test_forward_signature(self):
- pass
-
- def test_training(self):
- pass
-
- @require_torch_accelerator_with_training
- def test_gradient_checkpointing(self):
- # enable deterministic behavior for gradient checkpointing
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
- model.to(torch_device)
-
- assert not model.is_gradient_checkpointing and model.training
-
- out = model(**inputs_dict).sample
- # run the backwards pass on the model. For backwards pass, for simplicity purpose,
- # we won't calculate the loss and rather backprop on out.sum()
- model.zero_grad()
-
- labels = torch.randn_like(out)
- loss = (out - labels).mean()
- loss.backward()
-
- # re-instantiate the model now enabling gradient checkpointing
- model_2 = self.model_class(**init_dict)
- # clone model
- model_2.load_state_dict(model.state_dict())
- model_2.to(torch_device)
- model_2.enable_gradient_checkpointing()
-
- assert model_2.is_gradient_checkpointing and model_2.training
-
- out_2 = model_2(**inputs_dict).sample
- # run the backwards pass on the model. For backwards pass, for simplicity purpose,
- # we won't calculate the loss and rather backprop on out.sum()
- model_2.zero_grad()
- loss_2 = (out_2 - labels).mean()
- loss_2.backward()
-
- # compare the output and parameters gradients
- self.assertTrue((loss - loss_2).abs() < 1e-5)
- named_params = dict(model.named_parameters())
- named_params_2 = dict(model_2.named_parameters())
- for name, param in named_params.items():
- self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
- def test_from_pretrained_hub(self):
- model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
- self.assertIsNotNone(model)
- self.assertEqual(len(loading_info["missing_keys"]), 0)
-
- model.to(torch_device)
- image = model(**self.dummy_input)
-
- assert image is not None, "Make sure output is not None"
-
- def test_output_pretrained(self):
- model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
- model = model.to(torch_device)
- model.eval()
-
- # Keep generator on CPU for non-CUDA devices to compare outputs with CPU result tensors
- generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
- if torch_device != "mps":
- generator = torch.Generator(device=generator_device).manual_seed(0)
- else:
- generator = torch.manual_seed(0)
-
- image = torch.randn(
- 1,
- model.config.in_channels,
- model.config.sample_size,
- model.config.sample_size,
- generator=torch.manual_seed(0),
- )
- image = image.to(torch_device)
- with torch.no_grad():
- output = model(image, sample_posterior=True, generator=generator).sample
-
- output_slice = output[0, -1, -3:, -3:].flatten().cpu()
-
- # Since the VAE Gaussian prior's generator is seeded on the appropriate device,
- # the expected output slices are not the same for CPU and GPU.
- if torch_device == "mps":
- expected_output_slice = torch.tensor(
- [
- -4.0078e-01,
- -3.8323e-04,
- -1.2681e-01,
- -1.1462e-01,
- 2.0095e-01,
- 1.0893e-01,
- -8.8247e-02,
- -3.0361e-01,
- -9.8644e-03,
- ]
- )
- elif generator_device == "cpu":
- expected_output_slice = torch.tensor(
- [
- -0.1352,
- 0.0878,
- 0.0419,
- -0.0818,
- -0.1069,
- 0.0688,
- -0.1458,
- -0.4446,
- -0.0026,
- ]
- )
- else:
- expected_output_slice = torch.tensor(
- [
- -0.2421,
- 0.4642,
- 0.2507,
- -0.0438,
- 0.0682,
- 0.3160,
- -0.2018,
- -0.0727,
- 0.2485,
- ]
- )
-
- self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
-
-
-class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
- model_class = AsymmetricAutoencoderKL
- main_input_name = "sample"
- base_precision = 1e-2
-
- @property
- def dummy_input(self):
- batch_size = 4
- num_channels = 3
- sizes = (32, 32)
-
- image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
- mask = torch.ones((batch_size, 1) + sizes).to(torch_device)
-
- return {"sample": image, "mask": mask}
-
- @property
- def input_shape(self):
- return (3, 32, 32)
-
- @property
- def output_shape(self):
- return (3, 32, 32)
-
- def prepare_init_args_and_inputs_for_common(self):
- init_dict = get_asym_autoencoder_kl_config()
- inputs_dict = self.dummy_input
- return init_dict, inputs_dict
-
- def test_forward_signature(self):
- pass
-
- def test_forward_with_norm_groups(self):
- pass
-
-
-class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
- model_class = AutoencoderTiny
- main_input_name = "sample"
- base_precision = 1e-2
-
- @property
- def dummy_input(self):
- batch_size = 4
- num_channels = 3
- sizes = (32, 32)
-
- image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-
- return {"sample": image}
-
- @property
- def input_shape(self):
- return (3, 32, 32)
-
- @property
- def output_shape(self):
- return (3, 32, 32)
-
- def prepare_init_args_and_inputs_for_common(self):
- init_dict = get_autoencoder_tiny_config()
- inputs_dict = self.dummy_input
- return init_dict, inputs_dict
-
- def test_outputs_equivalence(self):
- pass
-
-
-class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
- model_class = ConsistencyDecoderVAE
- main_input_name = "sample"
- base_precision = 1e-2
- forward_requires_fresh_args = True
-
- def inputs_dict(self, seed=None):
- if seed is None:
- generator = torch.Generator("cpu").manual_seed(0)
- else:
- generator = torch.Generator("cpu").manual_seed(seed)
- image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
-
- return {"sample": image, "generator": generator}
-
- @property
- def input_shape(self):
- return (3, 32, 32)
-
- @property
- def output_shape(self):
- return (3, 32, 32)
-
- @property
- def init_dict(self):
- return get_consistency_vae_config()
-
- def prepare_init_args_and_inputs_for_common(self):
- return self.init_dict, self.inputs_dict()
-
- @unittest.skip
- def test_training(self):
- ...
-
- @unittest.skip
- def test_ema_training(self):
- ...
-
-
-class AutoencoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase):
- model_class = AutoencoderKLTemporalDecoder
- main_input_name = "sample"
- base_precision = 1e-2
-
- @property
- def dummy_input(self):
- batch_size = 3
- num_channels = 3
- sizes = (32, 32)
-
- image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
- num_frames = 3
-
- return {"sample": image, "num_frames": num_frames}
-
- @property
- def input_shape(self):
- return (3, 32, 32)
-
- @property
- def output_shape(self):
- return (3, 32, 32)
-
- def prepare_init_args_and_inputs_for_common(self):
- init_dict = {
- "block_out_channels": [32, 64],
- "in_channels": 3,
- "out_channels": 3,
- "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
- "latent_channels": 4,
- "layers_per_block": 2,
- }
- inputs_dict = self.dummy_input
- return init_dict, inputs_dict
-
- def test_forward_signature(self):
- pass
-
- def test_training(self):
- pass
-
- @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
- def test_gradient_checkpointing(self):
- # enable deterministic behavior for gradient checkpointing
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
- model.to(torch_device)
-
- assert not model.is_gradient_checkpointing and model.training
-
- out = model(**inputs_dict).sample
- # run the backwards pass on the model. For backwards pass, for simplicity purpose,
- # we won't calculate the loss and rather backprop on out.sum()
- model.zero_grad()
-
- labels = torch.randn_like(out)
- loss = (out - labels).mean()
- loss.backward()
-
- # re-instantiate the model now enabling gradient checkpointing
- model_2 = self.model_class(**init_dict)
- # clone model
- model_2.load_state_dict(model.state_dict())
- model_2.to(torch_device)
- model_2.enable_gradient_checkpointing()
-
- assert model_2.is_gradient_checkpointing and model_2.training
-
- out_2 = model_2(**inputs_dict).sample
- # run the backwards pass on the model. For backwards pass, for simplicity purpose,
- # we won't calculate the loss and rather backprop on out.sum()
- model_2.zero_grad()
- loss_2 = (out_2 - labels).mean()
- loss_2.backward()
-
- # compare the output and parameters gradients
- self.assertTrue((loss - loss_2).abs() < 1e-5)
- named_params = dict(model.named_parameters())
- named_params_2 = dict(model_2.named_parameters())
- for name, param in named_params.items():
- if "post_quant_conv" in name:
- continue
-
- self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
-
-class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
- model_class = AutoencoderOobleck
- main_input_name = "sample"
- base_precision = 1e-2
-
- @property
- def dummy_input(self):
- batch_size = 4
- num_channels = 2
- seq_len = 24
-
- waveform = floats_tensor((batch_size, num_channels, seq_len)).to(torch_device)
-
- return {"sample": waveform, "sample_posterior": False}
-
- @property
- def input_shape(self):
- return (2, 24)
-
- @property
- def output_shape(self):
- return (2, 24)
-
- def prepare_init_args_and_inputs_for_common(self):
- init_dict = get_autoencoder_oobleck_config()
- inputs_dict = self.dummy_input
- return init_dict, inputs_dict
-
- def test_forward_signature(self):
- pass
-
- def test_forward_with_norm_groups(self):
- pass
-
- @unittest.skip("No attention module used in this model")
- def test_set_attn_processor_for_determinism(self):
- return
-
-
-@slow
-class AutoencoderTinyIntegrationTests(unittest.TestCase):
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_file_format(self, seed, shape):
- return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
-
- def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
- dtype = torch.float16 if fp16 else torch.float32
- image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
- return image
-
- def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=False):
- torch_dtype = torch.float16 if fp16 else torch.float32
-
- model = AutoencoderTiny.from_pretrained(model_id, torch_dtype=torch_dtype)
- model.to(torch_device).eval()
- return model
-
- @parameterized.expand(
- [
- [(1, 4, 73, 97), (1, 3, 584, 776)],
- [(1, 4, 97, 73), (1, 3, 776, 584)],
- [(1, 4, 49, 65), (1, 3, 392, 520)],
- [(1, 4, 65, 49), (1, 3, 520, 392)],
- [(1, 4, 49, 49), (1, 3, 392, 392)],
- ]
- )
- def test_tae_tiling(self, in_shape, out_shape):
- model = self.get_sd_vae_model()
- model.enable_tiling()
- with torch.no_grad():
- zeros = torch.zeros(in_shape).to(torch_device)
- dec = model.decode(zeros).sample
- assert dec.shape == out_shape
-
- def test_stable_diffusion(self):
- model = self.get_sd_vae_model()
- image = self.get_sd_image(seed=33)
-
- with torch.no_grad():
- sample = model(image).sample
-
- assert sample.shape == image.shape
-
- output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
- expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
-
- assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
-
- @parameterized.expand([(True,), (False,)])
- def test_tae_roundtrip(self, enable_tiling):
- # load the autoencoder
- model = self.get_sd_vae_model()
- if enable_tiling:
- model.enable_tiling()
-
- # make a black image with a white square in the middle,
- # which is large enough to split across multiple tiles
- image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
- image[..., 256:768, 256:768] = 1.0
-
- # round-trip the image through the autoencoder
- with torch.no_grad():
- sample = model(image).sample
-
- # the autoencoder reconstruction should match original image, sorta
- def downscale(x):
- return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)
-
- assert torch_all_close(downscale(sample), downscale(image), atol=0.125)
-
-
-@slow
-class AutoencoderKLIntegrationTests(unittest.TestCase):
- def get_file_format(self, seed, shape):
- return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
- dtype = torch.float16 if fp16 else torch.float32
- image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
- return image
-
- def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
- revision = "fp16" if fp16 else None
- torch_dtype = torch.float16 if fp16 else torch.float32
-
- model = AutoencoderKL.from_pretrained(
- model_id,
- subfolder="vae",
- torch_dtype=torch_dtype,
- revision=revision,
- )
- model.to(torch_device)
-
- return model
-
- def get_generator(self, seed=0):
- generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
- if torch_device != "mps":
- return torch.Generator(device=generator_device).manual_seed(seed)
- return torch.manual_seed(seed)
-
- @parameterized.expand(
- [
- # fmt: off
- [
- 33,
- [-0.1556, 0.9848, -0.0410, -0.0642, -0.2685, 0.8381, -0.2004, -0.0700],
- [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824],
- ],
- [
- 47,
- [-0.2376, 0.1200, 0.1337, -0.4830, -0.2504, -0.0759, -0.0486, -0.4077],
- [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131],
- ],
- # fmt: on
- ]
- )
- def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
- model = self.get_sd_vae_model()
- image = self.get_sd_image(seed)
- generator = self.get_generator(seed)
-
- with torch.no_grad():
- sample = model(image, generator=generator, sample_posterior=True).sample
-
- assert sample.shape == image.shape
-
- output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
- expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
-
- assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
-
- @parameterized.expand(
- [
- # fmt: off
- [33, [-0.0513, 0.0289, 1.3799, 0.2166, -0.2573, -0.0871, 0.5103, -0.0999]],
- [47, [-0.4128, -0.1320, -0.3704, 0.1965, -0.4116, -0.2332, -0.3340, 0.2247]],
- # fmt: on
- ]
- )
- @require_torch_accelerator_with_fp16
- def test_stable_diffusion_fp16(self, seed, expected_slice):
- model = self.get_sd_vae_model(fp16=True)
- image = self.get_sd_image(seed, fp16=True)
- generator = self.get_generator(seed)
-
- with torch.no_grad():
- sample = model(image, generator=generator, sample_posterior=True).sample
-
- assert sample.shape == image.shape
-
- output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
- expected_output_slice = torch.tensor(expected_slice)
-
- assert torch_all_close(output_slice, expected_output_slice, atol=1e-2)
-
- @parameterized.expand(
- [
- # fmt: off
- [
- 33,
- [-0.1609, 0.9866, -0.0487, -0.0777, -0.2716, 0.8368, -0.2055, -0.0814],
- [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824],
- ],
- [
- 47,
- [-0.2377, 0.1147, 0.1333, -0.4841, -0.2506, -0.0805, -0.0491, -0.4085],
- [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131],
- ],
- # fmt: on
- ]
- )
- def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
- model = self.get_sd_vae_model()
- image = self.get_sd_image(seed)
-
- with torch.no_grad():
- sample = model(image).sample
-
- assert sample.shape == image.shape
-
- output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
- expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
-
- assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
-
- @parameterized.expand(
- [
- # fmt: off
- [13, [-0.2051, -0.1803, -0.2311, -0.2114, -0.3292, -0.3574, -0.2953, -0.3323]],
- [37, [-0.2632, -0.2625, -0.2199, -0.2741, -0.4539, -0.4990, -0.3720, -0.4925]],
- # fmt: on
- ]
- )
- @require_torch_accelerator
- @skip_mps
- def test_stable_diffusion_decode(self, seed, expected_slice):
- model = self.get_sd_vae_model()
- encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
-
- with torch.no_grad():
- sample = model.decode(encoding).sample
-
- assert list(sample.shape) == [3, 3, 512, 512]
-
- output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
- expected_output_slice = torch.tensor(expected_slice)
-
- assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
-
- @parameterized.expand(
- [
- # fmt: off
- [27, [-0.0369, 0.0207, -0.0776, -0.0682, -0.1747, -0.1930, -0.1465, -0.2039]],
- [16, [-0.1628, -0.2134, -0.2747, -0.2642, -0.3774, -0.4404, -0.3687, -0.4277]],
- # fmt: on
- ]
- )
- @require_torch_accelerator_with_fp16
- def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
- model = self.get_sd_vae_model(fp16=True)
- encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
-
- with torch.no_grad():
- sample = model.decode(encoding).sample
-
- assert list(sample.shape) == [3, 3, 512, 512]
-
- output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
- expected_output_slice = torch.tensor(expected_slice)
-
- assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
-
- @parameterized.expand([(13,), (16,), (27,)])
- @require_torch_gpu
- @unittest.skipIf(
- not is_xformers_available(),
- reason="xformers is not required when using PyTorch 2.0.",
- )
- def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
- model = self.get_sd_vae_model(fp16=True)
- encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
-
- with torch.no_grad():
- sample = model.decode(encoding).sample
-
- model.enable_xformers_memory_efficient_attention()
- with torch.no_grad():
- sample_2 = model.decode(encoding).sample
-
- assert list(sample.shape) == [3, 3, 512, 512]
-
- assert torch_all_close(sample, sample_2, atol=1e-1)
-
- @parameterized.expand([(13,), (16,), (37,)])
- @require_torch_gpu
- @unittest.skipIf(
- not is_xformers_available(),
- reason="xformers is not required when using PyTorch 2.0.",
- )
- def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
- model = self.get_sd_vae_model()
- encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
-
- with torch.no_grad():
- sample = model.decode(encoding).sample
-
- model.enable_xformers_memory_efficient_attention()
- with torch.no_grad():
- sample_2 = model.decode(encoding).sample
-
- assert list(sample.shape) == [3, 3, 512, 512]
-
- assert torch_all_close(sample, sample_2, atol=1e-2)
-
- @parameterized.expand(
- [
- # fmt: off
- [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
- [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
- # fmt: on
- ]
- )
- def test_stable_diffusion_encode_sample(self, seed, expected_slice):
- model = self.get_sd_vae_model()
- image = self.get_sd_image(seed)
- generator = self.get_generator(seed)
-
- with torch.no_grad():
- dist = model.encode(image).latent_dist
- sample = dist.sample(generator=generator)
-
- assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
-
- output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
- expected_output_slice = torch.tensor(expected_slice)
-
- tolerance = 3e-3 if torch_device != "mps" else 1e-2
- assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
-
-
-@slow
-class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
- def get_file_format(self, seed, shape):
- return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
- dtype = torch.float16 if fp16 else torch.float32
- image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
- return image
-
- def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
- revision = "main"
- torch_dtype = torch.float32
-
- model = AsymmetricAutoencoderKL.from_pretrained(
- model_id,
- torch_dtype=torch_dtype,
- revision=revision,
- )
- model.to(torch_device).eval()
-
- return model
-
- def get_generator(self, seed=0):
- generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
- if torch_device != "mps":
- return torch.Generator(device=generator_device).manual_seed(seed)
- return torch.manual_seed(seed)
-
- @parameterized.expand(
- [
- # fmt: off
- [
- 33,
- [-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
- [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
- ],
- [
- 47,
- [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
- [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
- ],
- # fmt: on
- ]
- )
- def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
- model = self.get_sd_vae_model()
- image = self.get_sd_image(seed)
- generator = self.get_generator(seed)
-
- with torch.no_grad():
- sample = model(image, generator=generator, sample_posterior=True).sample
-
- assert sample.shape == image.shape
-
- output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
- expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
-
- assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
-
- @parameterized.expand(
- [
- # fmt: off
- [
- 33,
- [-0.0340, 0.2870, 0.1698, -0.0105, -0.3448, 0.3529, -0.1321, 0.1097],
- [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078],
- ],
- [
- 47,
- [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
- [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
- ],
- # fmt: on
- ]
- )
- def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
- model = self.get_sd_vae_model()
- image = self.get_sd_image(seed)
-
- with torch.no_grad():
- sample = model(image).sample
-
- assert sample.shape == image.shape
-
- output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
- expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
-
- assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
-
- @parameterized.expand(
- [
- # fmt: off
- [13, [-0.0521, -0.2939, 0.1540, -0.1855, -0.5936, -0.3138, -0.4579, -0.2275]],
- [37, [-0.1820, -0.4345, -0.0455, -0.2923, -0.8035, -0.5089, -0.4795, -0.3106]],
- # fmt: on
- ]
- )
- @require_torch_accelerator
- @skip_mps
- def test_stable_diffusion_decode(self, seed, expected_slice):
- model = self.get_sd_vae_model()
- encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
-
- with torch.no_grad():
- sample = model.decode(encoding).sample
-
- assert list(sample.shape) == [3, 3, 512, 512]
-
- output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
- expected_output_slice = torch.tensor(expected_slice)
-
- assert torch_all_close(output_slice, expected_output_slice, atol=2e-3)
-
- @parameterized.expand([(13,), (16,), (37,)])
- @require_torch_gpu
- @unittest.skipIf(
- not is_xformers_available(),
- reason="xformers is not required when using PyTorch 2.0.",
- )
- def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
- model = self.get_sd_vae_model()
- encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
-
- with torch.no_grad():
- sample = model.decode(encoding).sample
-
- model.enable_xformers_memory_efficient_attention()
- with torch.no_grad():
- sample_2 = model.decode(encoding).sample
-
- assert list(sample.shape) == [3, 3, 512, 512]
-
- assert torch_all_close(sample, sample_2, atol=5e-2)
-
- @parameterized.expand(
- [
- # fmt: off
- [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
- [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
- # fmt: on
- ]
- )
- def test_stable_diffusion_encode_sample(self, seed, expected_slice):
- model = self.get_sd_vae_model()
- image = self.get_sd_image(seed)
- generator = self.get_generator(seed)
-
- with torch.no_grad():
- dist = model.encode(image).latent_dist
- sample = dist.sample(generator=generator)
-
- assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]
-
- output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
- expected_output_slice = torch.tensor(expected_slice)
-
- tolerance = 3e-3 if torch_device != "mps" else 1e-2
- assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
-
-
-@slow
-class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- @torch.no_grad()
- def test_encode_decode(self):
- vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
- vae.to(torch_device)
-
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/img2img/sketch-mountains-input.jpg"
- ).resize((256, 256))
- image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :].to(
- torch_device
- )
-
- latent = vae.encode(image).latent_dist.mean
-
- sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
-
- actual_output = sample[0, :2, :2, :2].flatten().cpu()
- expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
-
- assert torch_all_close(actual_output, expected_output, atol=5e-3)
-
- def test_sd(self):
- vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
- pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, safety_checker=None
- )
- pipe.to(torch_device)
-
- out = pipe(
- "horse",
- num_inference_steps=2,
- output_type="pt",
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
-
- actual_output = out[:2, :2, :2].flatten().cpu()
- expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
-
- assert torch_all_close(actual_output, expected_output, atol=5e-3)
-
- def test_encode_decode_f16(self):
- vae = ConsistencyDecoderVAE.from_pretrained(
- "openai/consistency-decoder", torch_dtype=torch.float16
- ) # TODO - update
- vae.to(torch_device)
-
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/img2img/sketch-mountains-input.jpg"
- ).resize((256, 256))
- image = (
- torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
- .half()
- .to(torch_device)
- )
-
- latent = vae.encode(image).latent_dist.mean
-
- sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
-
- actual_output = sample[0, :2, :2, :2].flatten().cpu()
- expected_output = torch.tensor(
- [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471],
- dtype=torch.float16,
- )
-
- assert torch_all_close(actual_output, expected_output, atol=5e-3)
-
- def test_sd_f16(self):
- vae = ConsistencyDecoderVAE.from_pretrained(
- "openai/consistency-decoder", torch_dtype=torch.float16
- ) # TODO - update
- pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- torch_dtype=torch.float16,
- vae=vae,
- safety_checker=None,
- )
- pipe.to(torch_device)
-
- out = pipe(
- "horse",
- num_inference_steps=2,
- output_type="pt",
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
-
- actual_output = out[:2, :2, :2].flatten().cpu()
- expected_output = torch.tensor(
- [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035],
- dtype=torch.float16,
- )
-
- assert torch_all_close(actual_output, expected_output, atol=5e-3)
-
- def test_vae_tiling(self):
- vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
- pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, safety_checker=None, torch_dtype=torch.float16
- )
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- out_1 = pipe(
- "horse",
- num_inference_steps=2,
- output_type="pt",
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
-
- # make sure tiled vae decode yields the same result
- pipe.enable_vae_tiling()
- out_2 = pipe(
- "horse",
- num_inference_steps=2,
- output_type="pt",
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
-
- assert torch_all_close(out_1, out_2, atol=5e-3)
-
- # test that tiled decode works with various shapes
- shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
- with torch.no_grad():
- for shape in shapes:
- image = torch.zeros(shape, device=torch_device, dtype=pipe.vae.dtype)
- pipe.vae.decode(image)
-
-
-@slow
-class AutoencoderOobleckIntegrationTests(unittest.TestCase):
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def _load_datasamples(self, num_samples):
- ds = load_dataset(
- "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True
- )
- # automatic decoding with librispeech
- speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
-
- return torch.nn.utils.rnn.pad_sequence(
- [torch.from_numpy(x["array"]) for x in speech_samples], batch_first=True
- )
-
- def get_audio(self, audio_sample_size=2097152, fp16=False):
- dtype = torch.float16 if fp16 else torch.float32
- audio = self._load_datasamples(2).to(torch_device).to(dtype)
-
- # pad / crop to audio_sample_size
- audio = torch.nn.functional.pad(audio[:, :audio_sample_size], pad=(0, audio_sample_size - audio.shape[-1]))
-
- # todo channel
- audio = audio.unsqueeze(1).repeat(1, 2, 1).to(torch_device)
-
- return audio
-
- def get_oobleck_vae_model(self, model_id="stabilityai/stable-audio-open-1.0", fp16=False):
- torch_dtype = torch.float16 if fp16 else torch.float32
-
- model = AutoencoderOobleck.from_pretrained(
- model_id,
- subfolder="vae",
- torch_dtype=torch_dtype,
- )
- model.to(torch_device)
-
- return model
-
- def get_generator(self, seed=0):
- generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
- if torch_device != "mps":
- return torch.Generator(device=generator_device).manual_seed(seed)
- return torch.manual_seed(seed)
-
- @parameterized.expand(
- [
- # fmt: off
- [33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
- [44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
- # fmt: on
- ]
- )
- def test_stable_diffusion(self, seed, expected_slice, expected_mean_absolute_diff):
- model = self.get_oobleck_vae_model()
- audio = self.get_audio()
- generator = self.get_generator(seed)
-
- with torch.no_grad():
- sample = model(audio, generator=generator, sample_posterior=True).sample
-
- assert sample.shape == audio.shape
- assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
-
- output_slice = sample[-1, 1, 5:10].cpu()
- expected_output_slice = torch.tensor(expected_slice)
-
- assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)
-
- def test_stable_diffusion_mode(self):
- model = self.get_oobleck_vae_model()
- audio = self.get_audio()
-
- with torch.no_grad():
- sample = model(audio, sample_posterior=False).sample
-
- assert sample.shape == audio.shape
-
- @parameterized.expand(
- [
- # fmt: off
- [33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192],
- [44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196],
- # fmt: on
- ]
- )
- def test_stable_diffusion_encode_decode(self, seed, expected_slice, expected_mean_absolute_diff):
- model = self.get_oobleck_vae_model()
- audio = self.get_audio()
- generator = self.get_generator(seed)
-
- with torch.no_grad():
- x = audio
- posterior = model.encode(x).latent_dist
- z = posterior.sample(generator=generator)
- sample = model.decode(z).sample
-
- # (batch_size, latent_dim, sequence_length)
- assert posterior.mean.shape == (audio.shape[0], model.config.decoder_input_channels, 1024)
-
- assert sample.shape == audio.shape
- assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6
-
- output_slice = sample[-1, 1, 5:10].cpu()
- expected_output_slice = torch.tensor(expected_slice)
-
- assert torch_all_close(output_slice, expected_output_slice, atol=1e-5)
diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py
index c61ae1bdf0ff..77abe139d785 100644
--- a/tests/models/autoencoders/test_models_vq.py
+++ b/tests/models/autoencoders/test_models_vq.py
@@ -65,9 +65,11 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
+ @unittest.skip("Test not supported.")
def test_forward_signature(self):
pass
+ @unittest.skip("Test not supported.")
def test_training(self):
pass
diff --git a/tests/models/autoencoders/vae.py b/tests/models/autoencoders/vae.py
new file mode 100644
index 000000000000..f8055f1c1cb0
--- /dev/null
+++ b/tests/models/autoencoders/vae.py
@@ -0,0 +1,86 @@
+def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
+ block_out_channels = block_out_channels or [2, 4]
+ norm_num_groups = norm_num_groups or 2
+ init_dict = {
+ "block_out_channels": block_out_channels,
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+ "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+ "latent_channels": 4,
+ "norm_num_groups": norm_num_groups,
+ }
+ return init_dict
+
+
+def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
+ block_out_channels = block_out_channels or [2, 4]
+ norm_num_groups = norm_num_groups or 2
+ init_dict = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+ "down_block_out_channels": block_out_channels,
+ "layers_per_down_block": 1,
+ "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+ "up_block_out_channels": block_out_channels,
+ "layers_per_up_block": 1,
+ "act_fn": "silu",
+ "latent_channels": 4,
+ "norm_num_groups": norm_num_groups,
+ "sample_size": 32,
+ "scaling_factor": 0.18215,
+ }
+ return init_dict
+
+
+def get_autoencoder_tiny_config(block_out_channels=None):
+ block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
+ init_dict = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "encoder_block_out_channels": block_out_channels,
+ "decoder_block_out_channels": block_out_channels,
+ "num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
+ "num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
+ }
+ return init_dict
+
+
+def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
+ block_out_channels = block_out_channels or [2, 4]
+ norm_num_groups = norm_num_groups or 2
+ return {
+ "encoder_block_out_channels": block_out_channels,
+ "encoder_in_channels": 3,
+ "encoder_out_channels": 4,
+ "encoder_down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+ "decoder_add_attention": False,
+ "decoder_block_out_channels": block_out_channels,
+ "decoder_down_block_types": ["ResnetDownsampleBlock2D"] * len(block_out_channels),
+ "decoder_downsample_padding": 1,
+ "decoder_in_channels": 7,
+ "decoder_layers_per_block": 1,
+ "decoder_norm_eps": 1e-05,
+ "decoder_norm_num_groups": norm_num_groups,
+ "encoder_norm_num_groups": norm_num_groups,
+ "decoder_num_train_timesteps": 1024,
+ "decoder_out_channels": 6,
+ "decoder_resnet_time_scale_shift": "scale_shift",
+ "decoder_time_embedding_type": "learned",
+ "decoder_up_block_types": ["ResnetUpsampleBlock2D"] * len(block_out_channels),
+ "scaling_factor": 1,
+ "latent_channels": 4,
+ }
+
+
+def get_autoencoder_oobleck_config(block_out_channels=None):
+ init_dict = {
+ "encoder_hidden_size": 12,
+ "decoder_channels": 12,
+ "decoder_input_channels": 6,
+ "audio_channels": 2,
+ "downsampling_ratios": [2, 4],
+ "channel_multiples": [1, 2],
+ }
+ return init_dict
diff --git a/tests/models/test_attention_processor.py b/tests/models/test_attention_processor.py
index 2489604274b4..d070f6ea33e3 100644
--- a/tests/models/test_attention_processor.py
+++ b/tests/models/test_attention_processor.py
@@ -2,10 +2,12 @@
import unittest
import numpy as np
+import pytest
import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
+from diffusers.utils.testing_utils import torch_device
class AttnAddedKVProcessorTests(unittest.TestCase):
@@ -79,6 +81,15 @@ def test_only_cross_attention(self):
class DeprecatedAttentionBlockTests(unittest.TestCase):
+ @pytest.fixture(scope="session")
+ def is_dist_enabled(pytestconfig):
+ return pytestconfig.getoption("dist") == "loadfile"
+
+ @pytest.mark.xfail(
+ condition=torch.device(torch_device).type == "cuda" and is_dist_enabled,
+ reason="Test currently fails on our GPU CI because of `loadfile`. Note that it only fails when the tests are distributed from `pytest ... tests/models`. If the tests are run individually, even with `loadfile` it won't fail.",
+ strict=True,
+ )
def test_conversion_when_using_device_map(self):
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 5548fdd0723d..fc4a3128dd9f 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -13,26 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
+import gc
import inspect
import json
import os
+import re
import tempfile
import traceback
import unittest
import unittest.mock as mock
import uuid
-from typing import Dict, List, Tuple
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import requests_mock
import torch
-from accelerate.utils import compute_module_sizes
+import torch.nn as nn
+from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
from huggingface_hub import ModelCard, delete_repo, snapshot_download
from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
from requests.exceptions import HTTPError
-from diffusers.models import UNet2DConditionModel
+from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
@@ -43,6 +48,7 @@
from diffusers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
WEIGHTS_INDEX_NAME,
+ is_peft_available,
is_torch_npu_available,
is_xformers_available,
logging,
@@ -52,17 +58,25 @@
CaptureLogger,
get_python_version,
is_torch_compile,
+ numpy_cosine_similarity_distance,
require_torch_2,
+ require_torch_accelerator,
require_torch_accelerator_with_training,
require_torch_gpu,
- require_torch_multi_gpu,
+ require_torch_multi_accelerator,
run_test_in_subprocess,
+ torch_all_close,
torch_device,
)
+from diffusers.utils.torch_utils import get_torch_cuda_device_capability
from ..others.test_utils import TOKEN, USER, is_staging_test
+if is_peft_available():
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+
def caculate_expected_num_shards(index_map_path):
with open(index_map_path) as f:
weight_map_dict = json.load(f)["weight_map"]
@@ -72,6 +86,16 @@ def caculate_expected_num_shards(index_map_path):
return expected_num_shards
+def check_if_lora_correctly_set(model) -> bool:
+ """
+ Checks if the LoRA layers are correctly set with peft
+ """
+ for module in model.modules():
+ if isinstance(module, BaseTunerLayer):
+ return True
+ return False
+
+
# Will be run via run_test_in_subprocess
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
error = None
@@ -96,16 +120,92 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
out_queue.join()
+def named_persistent_module_tensors(
+ module: nn.Module,
+ recurse: bool = False,
+):
+ """
+ A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
+
+ Args:
+ module (`torch.nn.Module`):
+ The module we want the tensors on.
+ recurse (`bool`, *optional`, defaults to `False`):
+ Whether or not to go look in every submodule or just return the direct parameters and buffers.
+ """
+ yield from module.named_parameters(recurse=recurse)
+
+ for named_buffer in module.named_buffers(recurse=recurse):
+ name, _ = named_buffer
+ # Get parent by splitting on dots and traversing the model
+ parent = module
+ if "." in name:
+ parent_name = name.rsplit(".", 1)[0]
+ for part in parent_name.split("."):
+ parent = getattr(parent, part)
+ name = name.split(".")[-1]
+ if name not in parent._non_persistent_buffers_set:
+ yield named_buffer
+
+
+def compute_module_persistent_sizes(
+ model: nn.Module,
+ dtype: Optional[Union[str, torch.device]] = None,
+ special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
+):
+ """
+ Compute the size of each submodule of a given model (parameters + persistent buffers).
+ """
+ if dtype is not None:
+ dtype = _get_proper_dtype(dtype)
+ dtype_size = dtype_byte_size(dtype)
+ if special_dtypes is not None:
+ special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
+ special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
+ module_sizes = defaultdict(int)
+
+ module_list = []
+
+ module_list = named_persistent_module_tensors(model, recurse=True)
+
+ for name, tensor in module_list:
+ if special_dtypes is not None and name in special_dtypes:
+ size = tensor.numel() * special_dtypes_size[name]
+ elif dtype is None:
+ size = tensor.numel() * dtype_byte_size(tensor.dtype)
+ elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+ # According to the code in set_module_tensor_to_device, these types won't be converted
+ # so use their original size here
+ size = tensor.numel() * dtype_byte_size(tensor.dtype)
+ else:
+ size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
+ name_parts = name.split(".")
+ for idx in range(len(name_parts) + 1):
+ module_sizes[".".join(name_parts[:idx])] += size
+
+ return module_sizes
+
+
+def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype):
+ if torch.is_tensor(maybe_tensor):
+ return maybe_tensor.to(target_dtype) if maybe_tensor.dtype == current_dtype else maybe_tensor
+ if isinstance(maybe_tensor, dict):
+ return {k: cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for k, v in maybe_tensor.items()}
+ if isinstance(maybe_tensor, list):
+ return [cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for v in maybe_tensor]
+ return maybe_tensor
+
+
class ModelUtilsTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
- def test_accelerate_loading_error_message(self):
- with self.assertRaises(ValueError) as error_context:
+ def test_missing_key_loading_warning_message(self):
+ with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
# make sure that error message states what keys are missing
- assert "conv_out.bias" in str(error_context.exception)
+ assert "conv_out.bias" in " ".join(logs.output)
@parameterized.expand(
[
@@ -234,6 +334,58 @@ def test_weight_overwrite(self):
assert model.config.in_channels == 9
+ @require_torch_gpu
+ def test_keep_modules_in_fp32(self):
+ r"""
+ A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16
+ Also ensures if inference works.
+ """
+ fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
+
+ for torch_dtype in [torch.bfloat16, torch.float16]:
+ SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
+
+ model = SD3Transformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-sd3-pipe", subfolder="transformer", torch_dtype=torch_dtype
+ ).to(torch_device)
+
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name in model._keep_in_fp32_modules:
+ self.assertTrue(module.weight.dtype == torch.float32)
+ else:
+ self.assertTrue(module.weight.dtype == torch_dtype)
+
+ def get_dummy_inputs():
+ batch_size = 2
+ num_channels = 4
+ height = width = embedding_dim = 32
+ pooled_embedding_dim = embedding_dim * 2
+ sequence_length = 154
+
+ hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_prompt_embeds,
+ "timestep": timestep,
+ }
+
+ # test if inference works.
+ with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch_dtype):
+ input_dict_for_transformer = get_dummy_inputs()
+ model_inputs = {
+ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+ _ = model(**model_inputs)
+
+ SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
+
class UNetTesterMixin:
def test_forward_with_norm_groups(self):
@@ -458,7 +610,7 @@ def test_set_xformers_attn_processor_for_determinism(self):
assert torch.allclose(output, output_3, atol=self.base_precision)
assert torch.allclose(output_2, output_3, atol=self.base_precision)
- @require_torch_gpu
+ @require_torch_accelerator
def test_set_attn_processor_for_determinism(self):
if self.uses_custom_attn_processor:
return
@@ -587,8 +739,14 @@ def test_from_save_pretrained_dtype(self):
model.save_pretrained(tmpdirname, safe_serialization=False)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
- new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
- assert new_model.dtype == dtype
+ if (
+ hasattr(self.model_class, "_keep_in_fp32_modules")
+ and self.model_class._keep_in_fp32_modules is None
+ ):
+ new_model = self.model_class.from_pretrained(
+ tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
+ )
+ assert new_model.dtype == dtype
def test_determinism(self, expected_max_diff=1e-5):
if self.forward_requires_fresh_args:
@@ -785,6 +943,91 @@ def test_enable_disable_gradient_checkpointing(self):
model.disable_gradient_checkpointing()
self.assertFalse(model.is_gradient_checkpointing)
+ @require_torch_accelerator_with_training
+ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
+ if not self.model_class._supports_gradient_checkpointing:
+ return # Skip test if model does not support gradient checkpointing
+
+ # enable deterministic behavior for gradient checkpointing
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ inputs_dict_copy = copy.deepcopy(inputs_dict)
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ assert not model.is_gradient_checkpointing and model.training
+
+ out = model(**inputs_dict).sample
+ # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+ # we won't calculate the loss and rather backprop on out.sum()
+ model.zero_grad()
+
+ labels = torch.randn_like(out)
+ loss = (out - labels).mean()
+ loss.backward()
+
+ # re-instantiate the model now enabling gradient checkpointing
+ torch.manual_seed(0)
+ model_2 = self.model_class(**init_dict)
+ # clone model
+ model_2.load_state_dict(model.state_dict())
+ model_2.to(torch_device)
+ model_2.enable_gradient_checkpointing()
+
+ assert model_2.is_gradient_checkpointing and model_2.training
+
+ out_2 = model_2(**inputs_dict_copy).sample
+ # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+ # we won't calculate the loss and rather backprop on out.sum()
+ model_2.zero_grad()
+ loss_2 = (out_2 - labels).mean()
+ loss_2.backward()
+
+ # compare the output and parameters gradients
+ self.assertTrue((loss - loss_2).abs() < loss_tolerance)
+ named_params = dict(model.named_parameters())
+ named_params_2 = dict(model_2.named_parameters())
+
+ for name, param in named_params.items():
+ if "post_quant_conv" in name:
+ continue
+ if name in skip:
+ continue
+ # TODO(aryan): remove the below lines after looking into easyanimate transformer a little more
+ # It currently errors out the gradient checkpointing test because the gradients for attn2.to_out is None
+ if param.grad is None:
+ continue
+ self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
+
+ @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
+ def test_gradient_checkpointing_is_applied(
+ self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
+ ):
+ if not self.model_class._supports_gradient_checkpointing:
+ return # Skip test if model does not support gradient checkpointing
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ if attention_head_dim is not None:
+ init_dict["attention_head_dim"] = attention_head_dim
+ if num_attention_heads is not None:
+ init_dict["num_attention_heads"] = num_attention_heads
+ if block_out_channels is not None:
+ init_dict["block_out_channels"] = block_out_channels
+
+ model_class_copy = copy.copy(self.model_class)
+ model = model_class_copy(**init_dict)
+ model.enable_gradient_checkpointing()
+
+ modules_with_gc_enabled = {}
+ for submodule in model.modules():
+ if hasattr(submodule, "gradient_checkpointing"):
+ self.assertTrue(submodule.gradient_checkpointing)
+ modules_with_gc_enabled[submodule.__class__.__name__] = True
+
+ assert set(modules_with_gc_enabled.keys()) == expected_set
+ assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+
def test_deprecated_kwargs(self):
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
@@ -805,7 +1048,95 @@ def test_deprecated_kwargs(self):
" from `_deprecated_kwargs = []`"
)
- @require_torch_gpu
+ @parameterized.expand([True, False])
+ @torch.no_grad()
+ @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+ def test_save_load_lora_adapter(self, use_dora=False):
+ import safetensors
+ from peft import LoraConfig
+ from peft.utils import get_peft_model_state_dict
+
+ from diffusers.loaders.peft import PeftAdapterMixin
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ if not issubclass(model.__class__, PeftAdapterMixin):
+ return
+
+ torch.manual_seed(0)
+ output_no_lora = model(**inputs_dict, return_dict=False)[0]
+
+ denoiser_lora_config = LoraConfig(
+ r=4,
+ lora_alpha=4,
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+ init_lora_weights=False,
+ use_dora=use_dora,
+ )
+ model.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+ torch.manual_seed(0)
+ outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+
+ self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model.save_lora_adapter(tmpdir)
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+
+ model.unload_lora()
+ self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+ model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+ state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
+
+ for k in state_dict_loaded:
+ loaded_v = state_dict_loaded[k]
+ retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
+ self.assertTrue(torch.allclose(loaded_v, retrieved_v))
+
+ self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+ torch.manual_seed(0)
+ outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+
+ self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+ self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+
+ @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+ def test_wrong_adapter_name_raises_error(self):
+ from peft import LoraConfig
+
+ from diffusers.loaders.peft import PeftAdapterMixin
+
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ if not issubclass(model.__class__, PeftAdapterMixin):
+ return
+
+ denoiser_lora_config = LoraConfig(
+ r=4,
+ lora_alpha=4,
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+ init_lora_weights=False,
+ use_dora=False,
+ )
+ model.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ wrong_name = "foo"
+ with self.assertRaises(ValueError) as err_context:
+ model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
+
+ self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
+
+ @require_torch_accelerator
def test_cpu_offload(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
@@ -835,7 +1166,7 @@ def test_cpu_offload(self):
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
- @require_torch_gpu
+ @require_torch_accelerator
def test_disk_offload_without_safetensors(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
@@ -848,17 +1179,16 @@ def test_disk_offload_without_safetensors(self):
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
+ max_size = int(self.model_split_percents[0] * model_size)
+ # Force disk offload by setting very small CPU memory
+ max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
+
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
-
with self.assertRaises(ValueError):
- max_size = int(self.model_split_percents[0] * model_size)
- max_memory = {0: max_size, "cpu": max_size}
# This errors out because it's missing an offload folder
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
- max_size = int(self.model_split_percents[0] * model_size)
- max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
)
@@ -869,7 +1199,7 @@ def test_disk_offload_without_safetensors(self):
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
- @require_torch_gpu
+ @require_torch_accelerator
def test_disk_offload_with_safetensors(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
@@ -897,7 +1227,7 @@ def test_disk_offload_with_safetensors(self):
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
- @require_torch_multi_gpu
+ @require_torch_multi_accelerator
def test_model_parallelism(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
@@ -920,6 +1250,7 @@ def test_model_parallelism(self):
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+ print(f" new_model.hf_device_map:{new_model.hf_device_map}")
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
@@ -928,7 +1259,7 @@ def test_model_parallelism(self):
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
- @require_torch_gpu
+ @require_torch_accelerator
def test_sharded_checkpoints(self):
torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -937,7 +1268,7 @@ def test_sharded_checkpoints(self):
base_output = model(**inputs_dict)
- model_size = compute_module_sizes(model)[""]
+ model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -960,7 +1291,7 @@ def test_sharded_checkpoints(self):
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
- @require_torch_gpu
+ @require_torch_accelerator
def test_sharded_checkpoints_with_variant(self):
torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -969,7 +1300,7 @@ def test_sharded_checkpoints_with_variant(self):
base_output = model(**inputs_dict)
- model_size = compute_module_sizes(model)[""]
+ model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -998,7 +1329,7 @@ def test_sharded_checkpoints_with_variant(self):
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
- @require_torch_gpu
+ @require_torch_accelerator
def test_sharded_checkpoints_device_map(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
@@ -1009,7 +1340,7 @@ def test_sharded_checkpoints_device_map(self):
torch.manual_seed(0)
base_output = model(**inputs_dict)
- model_size = compute_module_sizes(model)[""]
+ model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
@@ -1038,7 +1369,7 @@ def test_variant_sharded_ckpt_right_format(self):
config, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
- model_size = compute_module_sizes(model)[""]
+ model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1068,6 +1399,175 @@ def test_variant_sharded_ckpt_right_format(self):
# Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)
+ def test_layerwise_casting_training(self):
+ def test_fn(storage_dtype, compute_dtype):
+ if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
+ return
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict)
+ model = model.to(torch_device, dtype=compute_dtype)
+ model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+ model.train()
+
+ inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+ with torch.amp.autocast(device_type=torch.device(torch_device).type):
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.to_tuple()[0]
+
+ input_tensor = inputs_dict[self.main_input_name]
+ noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
+ noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
+ loss = torch.nn.functional.mse_loss(output, noise)
+
+ loss.backward()
+
+ test_fn(torch.float16, torch.float32)
+ test_fn(torch.float8_e4m3fn, torch.float32)
+ test_fn(torch.float8_e5m2, torch.float32)
+ test_fn(torch.float8_e4m3fn, torch.bfloat16)
+
+ def test_layerwise_casting_inference(self):
+ from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+
+ torch.manual_seed(0)
+ config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**config).eval()
+ model = model.to(torch_device)
+ base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+
+ def check_linear_dtype(module, storage_dtype, compute_dtype):
+ patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+ if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
+ patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
+ for name, submodule in module.named_modules():
+ if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+ continue
+ dtype_to_check = storage_dtype
+ if any(re.search(pattern, name) for pattern in patterns_to_check):
+ dtype_to_check = compute_dtype
+ if getattr(submodule, "weight", None) is not None:
+ self.assertEqual(submodule.weight.dtype, dtype_to_check)
+ if getattr(submodule, "bias", None) is not None:
+ self.assertEqual(submodule.bias.dtype, dtype_to_check)
+
+ def test_layerwise_casting(storage_dtype, compute_dtype):
+ torch.manual_seed(0)
+ config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+ model = self.model_class(**config).eval()
+ model = model.to(torch_device, dtype=compute_dtype)
+ model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+
+ check_linear_dtype(model, storage_dtype, compute_dtype)
+ output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy()
+
+ # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
+ # We just want to make sure that the layerwise casting is working as expected.
+ self.assertTrue(numpy_cosine_similarity_distance(base_slice, output) < 1.0)
+
+ test_layerwise_casting(torch.float16, torch.float32)
+ test_layerwise_casting(torch.float8_e4m3fn, torch.float32)
+ test_layerwise_casting(torch.float8_e5m2, torch.float32)
+ test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
+
+ @require_torch_gpu
+ def test_layerwise_casting_memory(self):
+ MB_TOLERANCE = 0.2
+ LEAST_COMPUTE_CAPABILITY = 8.0
+
+ def reset_memory_stats():
+ gc.collect()
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ def get_memory_usage(storage_dtype, compute_dtype):
+ torch.manual_seed(0)
+ config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+ model = self.model_class(**config).eval()
+ model = model.to(torch_device, dtype=compute_dtype)
+ model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+
+ reset_memory_stats()
+ model(**inputs_dict)
+ model_memory_footprint = model.get_memory_footprint()
+ peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
+
+ return model_memory_footprint, peak_inference_memory_allocated_mb
+
+ fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
+ fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
+ fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
+ torch.float8_e4m3fn, torch.bfloat16
+ )
+
+ compute_capability = get_torch_cuda_device_capability()
+ self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
+ # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
+ # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
+ if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
+ self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
+ # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
+ # bytes. This only happens for some models, so we allow a small tolerance.
+ # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
+ self.assertTrue(
+ fp8_e4m3_fp32_max_memory < fp32_max_memory
+ or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
+ )
+
+ @require_torch_gpu
+ def test_group_offloading(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ torch.manual_seed(0)
+
+ @torch.no_grad()
+ def run_forward(model):
+ self.assertTrue(
+ all(
+ module._diffusers_hook.get_hook("group_offloading") is not None
+ for module in model.modules()
+ if hasattr(module, "_diffusers_hook")
+ )
+ )
+ model.eval()
+ return model(**inputs_dict)[0]
+
+ model = self.model_class(**init_dict)
+ if not getattr(model, "_supports_group_offloading", True):
+ return
+
+ model.to(torch_device)
+ output_without_group_offloading = run_forward(model)
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict)
+ model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
+ output_with_group_offloading1 = run_forward(model)
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict)
+ model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
+ output_with_group_offloading2 = run_forward(model)
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict)
+ model.enable_group_offload(torch_device, offload_type="leaf_level")
+ output_with_group_offloading3 = run_forward(model)
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict)
+ model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True)
+ output_with_group_offloading4 = run_forward(model)
+
+ self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
+ self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
+ self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
+ self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
+
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
diff --git a/tests/models/transformers/test_models_dit_transformer2d.py b/tests/models/transformers/test_models_dit_transformer2d.py
index b12cae1a8879..5f4a2f587e92 100644
--- a/tests/models/transformers/test_models_dit_transformer2d.py
+++ b/tests/models/transformers/test_models_dit_transformer2d.py
@@ -84,6 +84,13 @@ def test_correct_class_remapping_from_dict_config(self):
model = Transformer2DModel.from_config(init_dict)
assert isinstance(model, DiTTransformer2DModel)
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"DiTTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ def test_effective_gradient_checkpointing(self):
+ super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)
+
def test_correct_class_remapping_from_pretrained_config(self):
config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
model = Transformer2DModel.from_config(config)
diff --git a/tests/models/transformers/test_models_pixart_transformer2d.py b/tests/models/transformers/test_models_pixart_transformer2d.py
index 30293f5d35cb..a544a3fc4607 100644
--- a/tests/models/transformers/test_models_pixart_transformer2d.py
+++ b/tests/models/transformers/test_models_pixart_transformer2d.py
@@ -92,6 +92,10 @@ def test_output(self):
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
)
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"PixArtTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
def test_correct_class_remapping_from_dict_config(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = Transformer2DModel.from_config(init_dict)
diff --git a/tests/models/transformers/test_models_prior.py b/tests/models/transformers/test_models_prior.py
index d2ed10dfa1f6..471c1084c00c 100644
--- a/tests/models/transformers/test_models_prior.py
+++ b/tests/models/transformers/test_models_prior.py
@@ -132,7 +132,6 @@ def test_output_pretrained(self):
output = model(**input)[0]
output_slice = output[0, :5].flatten().cpu()
- print(output_slice)
# Since the VAE Gaussian prior's generator is seeded on the appropriate device,
# the expected output slices are not the same for CPU and GPU.
@@ -182,7 +181,6 @@ def test_kandinsky_prior(self, seed, expected_slice):
assert list(sample.shape) == [1, 768]
output_slice = sample[0, :8].flatten().cpu()
- print(output_slice)
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
diff --git a/tests/models/transformers/test_models_transformer_allegro.py b/tests/models/transformers/test_models_transformer_allegro.py
new file mode 100644
index 000000000000..3479803da61d
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_allegro.py
@@ -0,0 +1,83 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import AllegroTransformer3DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class AllegroTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = AllegroTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ num_frames = 2
+ height = 8
+ width = 8
+ embedding_dim = 16
+ sequence_length = 16
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim // 2)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 2, 8, 8)
+
+ @property
+ def output_shape(self):
+ return (4, 2, 8, 8)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_layers": 1,
+ "cross_attention_dim": 16,
+ "sample_width": 8,
+ "sample_height": 8,
+ "sample_frames": 8,
+ "caption_channels": 8,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"AllegroTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_aura_flow.py b/tests/models/transformers/test_models_transformer_aura_flow.py
index 376d8b57da4d..d1ff7d2c96d3 100644
--- a/tests/models/transformers/test_models_transformer_aura_flow.py
+++ b/tests/models/transformers/test_models_transformer_aura_flow.py
@@ -74,6 +74,10 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"AuraFlowTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
@unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply")
def test_set_attn_processor_for_determinism(self):
pass
diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py
index 6db4113cbd1b..2b3cca883d17 100644
--- a/tests/models/transformers/test_models_transformer_cogvideox.py
+++ b/tests/models/transformers/test_models_transformer_cogvideox.py
@@ -33,6 +33,7 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
+ model_split_percents = [0.7, 0.7, 0.8]
@property
def dummy_input(self):
@@ -71,13 +72,78 @@ def prepare_init_args_and_inputs_for_common(self):
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
- "num_layers": 1,
+ "num_layers": 2,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
"patch_size": 2,
+ "patch_size_t": None,
"temporal_compression_ratio": 4,
"max_text_seq_length": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"CogVideoXTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = CogVideoXTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ num_frames = 2
+ height = 8
+ width = 8
+ embedding_dim = 8
+ sequence_length = 8
+
+ hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (1, 4, 8, 8)
+
+ @property
+ def output_shape(self):
+ return (1, 4, 8, 8)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "in_channels": 4,
+ "out_channels": 4,
+ "time_embed_dim": 2,
+ "text_embed_dim": 8,
+ "num_layers": 2,
+ "sample_width": 8,
+ "sample_height": 8,
+ "sample_frames": 8,
+ "patch_size": 2,
+ "patch_size_t": 2,
+ "temporal_compression_ratio": 4,
+ "max_text_seq_length": 8,
+ "use_rotary_positional_embeddings": True,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"CogVideoXTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py
index 46612dbd9190..91c7c35fbd07 100644
--- a/tests/models/transformers/test_models_transformer_cogview3plus.py
+++ b/tests/models/transformers/test_models_transformer_cogview3plus.py
@@ -33,6 +33,7 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView3PlusTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
+ model_split_percents = [0.7, 0.6, 0.6]
@property
def dummy_input(self):
@@ -71,7 +72,7 @@ def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
- "num_layers": 1,
+ "num_layers": 2,
"attention_head_dim": 4,
"num_attention_heads": 2,
"out_channels": 4,
@@ -83,3 +84,7 @@ def prepare_init_args_and_inputs_for_common(self):
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"CogView3PlusTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_cogview4.py b/tests/models/transformers/test_models_transformer_cogview4.py
new file mode 100644
index 000000000000..e311ce77ea50
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_cogview4.py
@@ -0,0 +1,83 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import CogView4Transformer2DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = CogView4Transformer2DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ height = 8
+ width = 8
+ embedding_dim = 8
+ sequence_length = 8
+
+ hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
+ target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
+ crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ "original_size": original_size,
+ "target_size": target_size,
+ "crop_coords": crop_coords,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 8, 8)
+
+ @property
+ def output_shape(self):
+ return (4, 8, 8)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 2,
+ "in_channels": 4,
+ "num_layers": 2,
+ "attention_head_dim": 4,
+ "num_attention_heads": 4,
+ "out_channels": 4,
+ "text_embed_dim": 8,
+ "time_embed_dim": 8,
+ "condition_dim": 4,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"CogView4Transformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_consisid.py b/tests/models/transformers/test_models_transformer_consisid.py
new file mode 100644
index 000000000000..b848ed014074
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_consisid.py
@@ -0,0 +1,105 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import ConsisIDTransformer3DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class ConsisIDTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = ConsisIDTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ num_frames = 1
+ height = 8
+ width = 8
+ embedding_dim = 8
+ sequence_length = 8
+
+ hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ id_vit_hidden = [torch.ones([batch_size, 2, 2]).to(torch_device)] * 1
+ id_cond = torch.ones(batch_size, 2).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ "id_vit_hidden": id_vit_hidden,
+ "id_cond": id_cond,
+ }
+
+ @property
+ def input_shape(self):
+ return (1, 4, 8, 8)
+
+ @property
+ def output_shape(self):
+ return (1, 4, 8, 8)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "in_channels": 4,
+ "out_channels": 4,
+ "time_embed_dim": 2,
+ "text_embed_dim": 8,
+ "num_layers": 1,
+ "sample_width": 8,
+ "sample_height": 8,
+ "sample_frames": 8,
+ "patch_size": 2,
+ "temporal_compression_ratio": 4,
+ "max_text_seq_length": 8,
+ "cross_attn_interval": 1,
+ "is_kps": False,
+ "is_train_face": True,
+ "cross_attn_dim_head": 1,
+ "cross_attn_num_heads": 1,
+ "LFE_id_dim": 2,
+ "LFE_vit_dim": 2,
+ "LFE_depth": 5,
+ "LFE_dim_head": 8,
+ "LFE_num_heads": 2,
+ "LFE_num_id_token": 1,
+ "LFE_num_querie": 1,
+ "LFE_output_dim": 10,
+ "LFE_ff_mult": 1,
+ "LFE_num_scale": 1,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"ConsisIDTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_easyanimate.py b/tests/models/transformers/test_models_transformer_easyanimate.py
new file mode 100644
index 000000000000..9f10a7da0a76
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_easyanimate.py
@@ -0,0 +1,87 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import EasyAnimateTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class EasyAnimateTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = EasyAnimateTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ embedding_dim = 16
+ sequence_length = 16
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "timestep_cond": None,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_hidden_states_t5": None,
+ "inpaint_latents": None,
+ "control_latents": None,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 2, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 2, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "attention_head_dim": 16,
+ "num_attention_heads": 2,
+ "in_channels": 4,
+ "mmdit_layers": 2,
+ "num_layers": 2,
+ "out_channels": 4,
+ "patch_size": 2,
+ "sample_height": 60,
+ "sample_width": 90,
+ "text_embed_dim": 16,
+ "time_embed_dim": 8,
+ "time_position_encoding_type": "3d_rope",
+ "timestep_activation_fn": "silu",
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"EasyAnimateTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 6cf7a4f75707..c88b3dac8216 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -18,6 +18,8 @@
import torch
from diffusers import FluxTransformer2DModel
+from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
+from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
@@ -26,6 +28,56 @@
enable_full_determinism()
+def create_flux_ip_adapter_state_dict(model):
+ # "ip_adapter" (cross-attention weights)
+ ip_cross_attn_state_dict = {}
+ key_id = 0
+
+ for name in model.attn_processors.keys():
+ if name.startswith("single_transformer_blocks"):
+ continue
+
+ joint_attention_dim = model.config["joint_attention_dim"]
+ hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
+ sd = FluxIPAdapterJointAttnProcessor2_0(
+ hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
+ ).state_dict()
+ ip_cross_attn_state_dict.update(
+ {
+ f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+ f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
+ f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
+ f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
+ }
+ )
+
+ key_id += 1
+
+ # "image_proj" (ImageProjection layer weights)
+
+ image_projection = ImageProjection(
+ cross_attention_dim=model.config["joint_attention_dim"],
+ image_embed_dim=model.config["pooled_projection_dim"],
+ num_image_text_embeds=4,
+ )
+
+ ip_image_projection_state_dict = {}
+ sd = image_projection.state_dict()
+ ip_image_projection_state_dict.update(
+ {
+ "proj.weight": sd["image_embeds.weight"],
+ "proj.bias": sd["image_embeds.bias"],
+ "norm.weight": sd["norm.weight"],
+ "norm.bias": sd["norm.bias"],
+ }
+ )
+
+ del sd
+ ip_state_dict = {}
+ ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
+ return ip_state_dict
+
+
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
@@ -111,3 +163,7 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
)
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"FluxTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py
new file mode 100644
index 000000000000..495131ad6fd8
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py
@@ -0,0 +1,292 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import HunyuanVideoTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = HunyuanVideoTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 1
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ pooled_projection_dim = 8
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+ pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
+ encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+ guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_projections,
+ "encoder_attention_mask": encoder_attention_mask,
+ "guidance": guidance,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_attention_heads": 2,
+ "attention_head_dim": 10,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "num_refiner_layers": 1,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "guidance_embeds": True,
+ "text_embed_dim": 16,
+ "pooled_projection_dim": 8,
+ "rope_axes_dim": (2, 4, 4),
+ "image_condition_type": None,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"HunyuanVideoTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = HunyuanVideoTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 8
+ num_frames = 1
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ pooled_projection_dim = 8
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+ pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
+ encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+ guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_projections,
+ "encoder_attention_mask": encoder_attention_mask,
+ "guidance": guidance,
+ }
+
+ @property
+ def input_shape(self):
+ return (8, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 8,
+ "out_channels": 4,
+ "num_attention_heads": 2,
+ "attention_head_dim": 10,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "num_refiner_layers": 1,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "guidance_embeds": True,
+ "text_embed_dim": 16,
+ "pooled_projection_dim": 8,
+ "rope_axes_dim": (2, 4, 4),
+ "image_condition_type": None,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_output(self):
+ super().test_output(expected_output_shape=(1, *self.output_shape))
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"HunyuanVideoTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = HunyuanVideoTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 2 * 4 + 1
+ num_frames = 1
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ pooled_projection_dim = 8
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+ pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
+ encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_projections,
+ "encoder_attention_mask": encoder_attention_mask,
+ }
+
+ @property
+ def input_shape(self):
+ return (8, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 2 * 4 + 1,
+ "out_channels": 4,
+ "num_attention_heads": 2,
+ "attention_head_dim": 10,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "num_refiner_layers": 1,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "guidance_embeds": False,
+ "text_embed_dim": 16,
+ "pooled_projection_dim": 8,
+ "rope_axes_dim": (2, 4, 4),
+ "image_condition_type": "latent_concat",
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_output(self):
+ super().test_output(expected_output_shape=(1, *self.output_shape))
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"HunyuanVideoTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = HunyuanVideoTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 2
+ num_frames = 1
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ pooled_projection_dim = 8
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+ pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
+ encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+ guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_projections,
+ "encoder_attention_mask": encoder_attention_mask,
+ "guidance": guidance,
+ }
+
+ @property
+ def input_shape(self):
+ return (8, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 2,
+ "out_channels": 4,
+ "num_attention_heads": 2,
+ "attention_head_dim": 10,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "num_refiner_layers": 1,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "guidance_embeds": True,
+ "text_embed_dim": 16,
+ "pooled_projection_dim": 8,
+ "rope_axes_dim": (2, 4, 4),
+ "image_condition_type": "token_replace",
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_output(self):
+ super().test_output(expected_output_shape=(1, *self.output_shape))
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"HunyuanVideoTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_latte.py b/tests/models/transformers/test_models_transformer_latte.py
index 3fe0a6098045..0cb9094f5165 100644
--- a/tests/models/transformers/test_models_transformer_latte.py
+++ b/tests/models/transformers/test_models_transformer_latte.py
@@ -86,3 +86,7 @@ def test_output(self):
super().test_output(
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
)
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"LatteTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py
new file mode 100644
index 000000000000..128bf04155e7
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_ltx.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import LTXVideoTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = LTXVideoTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ embedding_dim = 16
+ sequence_length = 16
+
+ hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ "encoder_attention_mask": encoder_attention_mask,
+ "num_frames": num_frames,
+ "height": height,
+ "width": width,
+ }
+
+ @property
+ def input_shape(self):
+ return (512, 4)
+
+ @property
+ def output_shape(self):
+ return (512, 4)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "cross_attention_dim": 16,
+ "num_layers": 1,
+ "qk_norm": "rms_norm_across_heads",
+ "caption_channels": 16,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"LTXVideoTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_lumina2.py b/tests/models/transformers/test_models_transformer_lumina2.py
new file mode 100644
index 000000000000..4db3ae68aa94
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_lumina2.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import Lumina2Transformer2DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = Lumina2Transformer2DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 2 # N
+ num_channels = 4 # C
+ height = width = 16 # H, W
+ embedding_dim = 32 # D
+ sequence_length = 16 # L
+
+ hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ timestep = torch.rand(size=(batch_size,)).to(torch_device)
+ attention_mask = torch.ones(size=(batch_size, sequence_length), dtype=torch.bool).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ "encoder_attention_mask": attention_mask,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "sample_size": 16,
+ "patch_size": 2,
+ "in_channels": 4,
+ "hidden_size": 24,
+ "num_layers": 2,
+ "num_refiner_layers": 1,
+ "num_attention_heads": 3,
+ "num_kv_heads": 1,
+ "multiple_of": 2,
+ "ffn_dim_multiplier": None,
+ "norm_eps": 1e-5,
+ "scaling_factor": 1.0,
+ "axes_dim_rope": (4, 2, 2),
+ "axes_lens": (128, 128, 128),
+ "cap_feat_dim": 32,
+ }
+
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"Lumina2Transformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_mochi.py b/tests/models/transformers/test_models_transformer_mochi.py
new file mode 100644
index 000000000000..d284ab942949
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_mochi.py
@@ -0,0 +1,86 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import MochiTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = MochiTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+ # Overriding it because of the transformer size.
+ model_split_percents = [0.7, 0.6, 0.6]
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ embedding_dim = 16
+ sequence_length = 16
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ "encoder_attention_mask": encoder_attention_mask,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 2, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 2, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 2,
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "num_layers": 2,
+ "pooled_projection_dim": 16,
+ "in_channels": 4,
+ "out_channels": None,
+ "qk_norm": "rms_norm",
+ "text_embed_dim": 16,
+ "time_embed_dim": 4,
+ "activation_fn": "swiglu",
+ "max_sequence_length": 16,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"MochiTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_omnigen.py b/tests/models/transformers/test_models_transformer_omnigen.py
new file mode 100644
index 000000000000..1bdcc68b0378
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_omnigen.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import OmniGenTransformer2DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = OmniGenTransformer2DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+ model_split_percents = [0.1, 0.1, 0.1]
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ height = 8
+ width = 8
+ sequence_length = 24
+
+ hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+ timestep = torch.rand(size=(batch_size,), dtype=hidden_states.dtype).to(torch_device)
+ input_ids = torch.randint(0, 10, (batch_size, sequence_length)).to(torch_device)
+ input_img_latents = [torch.randn((1, num_channels, height, width)).to(torch_device)]
+ input_image_sizes = {0: [[0, 0 + height * width // 2 // 2]]}
+
+ attn_seq_length = sequence_length + 1 + height * width // 2 // 2
+ attention_mask = torch.ones((batch_size, attn_seq_length, attn_seq_length)).to(torch_device)
+ position_ids = torch.LongTensor([list(range(attn_seq_length))] * batch_size).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "input_ids": input_ids,
+ "input_img_latents": input_img_latents,
+ "input_image_sizes": input_image_sizes,
+ "attention_mask": attention_mask,
+ "position_ids": position_ids,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 8, 8)
+
+ @property
+ def output_shape(self):
+ return (4, 8, 8)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "hidden_size": 16,
+ "num_attention_heads": 4,
+ "num_key_value_heads": 4,
+ "intermediate_size": 32,
+ "num_layers": 20,
+ "pad_token_id": 0,
+ "vocab_size": 1000,
+ "in_channels": 4,
+ "time_step_dim": 4,
+ "rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"OmniGenTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_sana.py b/tests/models/transformers/test_models_transformer_sana.py
new file mode 100644
index 000000000000..d4dc30f5d7a8
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_sana.py
@@ -0,0 +1,83 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import SanaTransformer2DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = SanaTransformer2DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+ model_split_percents = [0.7, 0.7, 0.9]
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ height = 32
+ width = 32
+ embedding_dim = 8
+ sequence_length = 8
+
+ hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (4, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 1,
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_layers": 1,
+ "attention_head_dim": 4,
+ "num_attention_heads": 2,
+ "num_cross_attention_heads": 2,
+ "cross_attention_head_dim": 4,
+ "cross_attention_dim": 8,
+ "caption_channels": 8,
+ "sample_size": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"SanaTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py
index 2be4744c5ac4..659d9a82fd76 100644
--- a/tests/models/transformers/test_models_transformer_sd3.py
+++ b/tests/models/transformers/test_models_transformer_sd3.py
@@ -18,6 +18,7 @@
import torch
from diffusers import SD3Transformer2DModel
+from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
@@ -32,6 +33,7 @@
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SD3Transformer2DModel
main_input_name = "hidden_states"
+ model_split_percents = [0.8, 0.8, 0.9]
@property
def dummy_input(self):
@@ -66,7 +68,7 @@ def prepare_init_args_and_inputs_for_common(self):
"sample_size": 32,
"patch_size": 1,
"in_channels": 4,
- "num_layers": 1,
+ "num_layers": 4,
"attention_head_dim": 8,
"num_attention_heads": 4,
"caption_projection_dim": 32,
@@ -80,14 +82,33 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_enable_works(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+
+ model.enable_xformers_memory_efficient_attention()
+
+ assert (
+ model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
+ ), "xformers is not enabled"
+
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
pass
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"SD3Transformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SD3Transformer2DModel
main_input_name = "hidden_states"
+ model_split_percents = [0.8, 0.8, 0.9]
@property
def dummy_input(self):
@@ -122,7 +143,7 @@ def prepare_init_args_and_inputs_for_common(self):
"sample_size": 32,
"patch_size": 1,
"in_channels": 4,
- "num_layers": 2,
+ "num_layers": 4,
"attention_head_dim": 8,
"num_attention_heads": 4,
"caption_projection_dim": 32,
@@ -136,6 +157,44 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_enable_works(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+
+ model.enable_xformers_memory_efficient_attention()
+
+ assert (
+ model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
+ ), "xformers is not enabled"
+
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
pass
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"SD3Transformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ def test_skip_layers(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ # Forward pass without skipping layers
+ output_full = model(**inputs_dict).sample
+
+ # Forward pass with skipping layers 0 (since there's only one layer in this test setup)
+ inputs_dict_with_skip = inputs_dict.copy()
+ inputs_dict_with_skip["skip_layers"] = [0]
+ output_skip = model(**inputs_dict_with_skip).sample
+
+ # Check that the outputs are different
+ self.assertFalse(
+ torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
+ )
+
+ # Check that the outputs have the same shape
+ self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py
new file mode 100644
index 000000000000..3ac64c628988
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_wan.py
@@ -0,0 +1,81 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import WanTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = WanTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "in_channels": 4,
+ "out_channels": 4,
+ "text_dim": 16,
+ "freq_dim": 256,
+ "ffn_dim": 32,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"WanTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py
index 9f7ef3bca085..7e160f9c128b 100644
--- a/tests/models/unets/test_models_unet_1d.py
+++ b/tests/models/unets/test_models_unet_1d.py
@@ -15,6 +15,7 @@
import unittest
+import pytest
import torch
from diffusers import UNet1DModel
@@ -51,12 +52,18 @@ def input_shape(self):
def output_shape(self):
return (4, 14, 16)
+ @unittest.skip("Test not supported.")
def test_ema_training(self):
pass
+ @unittest.skip("Test not supported.")
def test_training(self):
pass
+ @unittest.skip("Test not supported.")
+ def test_layerwise_casting_training(self):
+ pass
+
def test_determinism(self):
super().test_determinism()
@@ -126,6 +133,7 @@ def test_output_pretrained(self):
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
+ @unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
@@ -149,6 +157,28 @@ def test_unet_1d_maestro(self):
assert (output_sum - 224.0896).abs() < 0.5
assert (output_max - 0.0607).abs() < 4e-4
+ @pytest.mark.xfail(
+ reason=(
+ "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
+ "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
+ "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
+ "2. Unskip this test."
+ ),
+ )
+ def test_layerwise_casting_inference(self):
+ super().test_layerwise_casting_inference()
+
+ @pytest.mark.xfail(
+ reason=(
+ "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
+ "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
+ "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
+ "2. Unskip this test."
+ ),
+ )
+ def test_layerwise_casting_memory(self):
+ pass
+
class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet1DModel
@@ -205,12 +235,18 @@ def test_output(self):
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+ @unittest.skip("Test not supported.")
def test_ema_training(self):
pass
+ @unittest.skip("Test not supported.")
def test_training(self):
pass
+ @unittest.skip("Test not supported.")
+ def test_layerwise_casting_training(self):
+ pass
+
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 14,
@@ -265,6 +301,29 @@ def test_output_pretrained(self):
# fmt: on
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
+ @unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
+
+ @pytest.mark.xfail(
+ reason=(
+ "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
+ "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
+ "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
+ "2. Unskip this test."
+ ),
+ )
+ def test_layerwise_casting_inference(self):
+ pass
+
+ @pytest.mark.xfail(
+ reason=(
+ "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
+ "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
+ "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
+ "2. Unskip this test."
+ ),
+ )
+ def test_layerwise_casting_memory(self):
+ pass
diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py
index 5f827f274224..0e5fdc4bba2e 100644
--- a/tests/models/unets/test_models_unet_2d.py
+++ b/tests/models/unets/test_models_unet_2d.py
@@ -105,6 +105,52 @@ def test_mid_block_attn_groups(self):
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+ def test_mid_block_none(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ mid_none_init_dict["mid_block_type"] = None
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ mid_none_model = self.model_class(**mid_none_init_dict)
+ mid_none_model.to(torch_device)
+ mid_none_model.eval()
+
+ self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.to_tuple()[0]
+
+ with torch.no_grad():
+ mid_none_output = mid_none_model(**mid_none_inputs_dict)
+
+ if isinstance(mid_none_output, dict):
+ mid_none_output = mid_none_output.to_tuple()[0]
+
+ self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {
+ "AttnUpBlock2D",
+ "AttnDownBlock2D",
+ "UNetMidBlock2D",
+ "UpBlock2D",
+ "DownBlock2D",
+ }
+
+ # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+ attention_head_dim = 8
+ block_out_channels = (16, 32)
+
+ super().test_gradient_checkpointing_is_applied(
+ expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+ )
+
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
@@ -220,6 +266,17 @@ def test_output_pretrained(self):
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
+
+ # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
+ attention_head_dim = 32
+ block_out_channels = (32, 64)
+
+ super().test_gradient_checkpointing_is_applied(
+ expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+ )
+
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
@@ -326,6 +383,33 @@ def test_output_pretrained_ve_large(self):
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
+ @unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# not required for this model
pass
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {
+ "UNetMidBlock2D",
+ }
+
+ block_out_channels = (32, 64, 64, 64)
+
+ super().test_gradient_checkpointing_is_applied(
+ expected_set=expected_set, block_out_channels=block_out_channels
+ )
+
+ def test_effective_gradient_checkpointing(self):
+ super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
+
+ @unittest.skip(
+ "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
+ )
+ def test_layerwise_casting_inference(self):
+ pass
+
+ @unittest.skip(
+ "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
+ )
+ def test_layerwise_casting_memory(self):
+ pass
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 37d55cedeb28..8e1187f11468 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -36,6 +36,9 @@
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
floats_tensor,
is_peft_available,
@@ -43,7 +46,6 @@
require_peft_backend,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
- require_torch_accelerator_with_training,
require_torch_gpu,
skip_mps,
slow,
@@ -176,8 +178,7 @@ def create_ip_adapter_plus_state_dict(model):
)
ip_image_projection_state_dict = OrderedDict()
- keys = [k for k in image_projection.state_dict() if "layers." in k]
- print(keys)
+
for k, v in image_projection.state_dict().items():
if "2.to" in k:
k = k.replace("2.to", "0.to")
@@ -406,47 +407,6 @@ def test_xformers_enable_works(self):
== "XFormersAttnProcessor"
), "xformers is not enabled"
- @require_torch_accelerator_with_training
- def test_gradient_checkpointing(self):
- # enable deterministic behavior for gradient checkpointing
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
- model.to(torch_device)
-
- assert not model.is_gradient_checkpointing and model.training
-
- out = model(**inputs_dict).sample
- # run the backwards pass on the model. For backwards pass, for simplicity purpose,
- # we won't calculate the loss and rather backprop on out.sum()
- model.zero_grad()
-
- labels = torch.randn_like(out)
- loss = (out - labels).mean()
- loss.backward()
-
- # re-instantiate the model now enabling gradient checkpointing
- model_2 = self.model_class(**init_dict)
- # clone model
- model_2.load_state_dict(model.state_dict())
- model_2.to(torch_device)
- model_2.enable_gradient_checkpointing()
-
- assert model_2.is_gradient_checkpointing and model_2.training
-
- out_2 = model_2(**inputs_dict).sample
- # run the backwards pass on the model. For backwards pass, for simplicity purpose,
- # we won't calculate the loss and rather backprop on out.sum()
- model_2.zero_grad()
- loss_2 = (out_2 - labels).mean()
- loss_2.backward()
-
- # compare the output and parameters gradients
- self.assertTrue((loss - loss_2).abs() < 1e-5)
- named_params = dict(model.named_parameters())
- named_params_2 = dict(model_2.named_parameters())
- for name, param in named_params.items():
- self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
def test_model_with_attention_head_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -599,31 +559,7 @@ def check_sliceable_dim_attr(module: torch.nn.Module):
check_sliceable_dim_attr(module)
def test_gradient_checkpointing_is_applied(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- init_dict["block_out_channels"] = (16, 32)
- init_dict["attention_head_dim"] = (8, 16)
-
- model_class_copy = copy.copy(self.model_class)
-
- modules_with_gc_enabled = {}
-
- # now monkey patch the following function:
- # def _set_gradient_checkpointing(self, module, value=False):
- # if hasattr(module, "gradient_checkpointing"):
- # module.gradient_checkpointing = value
-
- def _set_gradient_checkpointing_new(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
- modules_with_gc_enabled[module.__class__.__name__] = True
-
- model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
- model = model_class_copy(**init_dict)
- model.enable_gradient_checkpointing()
-
- EXPECTED_SET = {
+ expected_set = {
"CrossAttnUpBlock2D",
"CrossAttnDownBlock2D",
"UNetMidBlock2DCrossAttn",
@@ -631,9 +567,11 @@ def _set_gradient_checkpointing_new(self, module, value=False):
"Transformer2DModel",
"DownBlock2D",
}
-
- assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
- assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+ attention_head_dim = (8, 16)
+ block_out_channels = (16, 32)
+ super().test_gradient_checkpointing_is_applied(
+ expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+ )
def test_special_attn_proc(self):
class AttnEasyProc(torch.nn.Module):
@@ -1067,7 +1005,7 @@ def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
- @require_torch_gpu
+ @require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
@@ -1078,7 +1016,7 @@ def test_load_sharded_checkpoint_from_hub_local(self):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
- @require_torch_gpu
+ @require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
@@ -1089,7 +1027,7 @@ def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
- @require_torch_gpu
+ @require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
@@ -1104,7 +1042,7 @@ def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
- @require_torch_gpu
+ @require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
@@ -1119,7 +1057,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, va
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
- @require_torch_gpu
+ @require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
@@ -1129,7 +1067,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
- @require_torch_gpu
+ @require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
@@ -1142,30 +1080,7 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
assert new_output.sample.shape == (4, 4, 16, 16)
@require_peft_backend
- def test_lora(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
- model.to(torch_device)
-
- # forward pass without LoRA
- with torch.no_grad():
- non_lora_sample = model(**inputs_dict).sample
-
- unet_lora_config = get_unet_lora_config()
- model.add_adapter(unet_lora_config)
-
- assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
-
- # forward pass with LoRA
- with torch.no_grad():
- lora_sample = model(**inputs_dict).sample
-
- assert not torch.allclose(
- non_lora_sample, lora_sample, atol=1e-4, rtol=1e-4
- ), "LoRA injected UNet should produce different results."
-
- @require_peft_backend
- def test_lora_serialization(self):
+ def test_load_attn_procs_raise_warning(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
@@ -1186,8 +1101,14 @@ def test_lora_serialization(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
model.unload_lora()
- model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+ with self.assertWarns(FutureWarning) as warning:
+ model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+ warning_message = str(warning.warnings[0].message)
+ assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
+
+ # import to still check for the rest of the stuff.
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad():
@@ -1200,6 +1121,24 @@ def test_lora_serialization(self):
lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
), "Loading from a saved checkpoint should produce identical results."
+ @require_peft_backend
+ def test_save_attn_procs_raise_warning(self):
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ unet_lora_config = get_unet_lora_config()
+ model.add_adapter(unet_lora_config)
+
+ assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ with self.assertWarns(FutureWarning) as warning:
+ model.save_attn_procs(tmpdirname)
+
+ warning_message = str(warning.warnings[0].message)
+ assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
+
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
@@ -1228,11 +1167,11 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
return model
- @require_torch_gpu
+ @require_torch_accelerator
def test_set_attention_slice_auto(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
unet = self.get_unet_model()
unet.set_attention_slice("auto")
@@ -1244,15 +1183,15 @@ def test_set_attention_slice_auto(self):
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 5 * 10**9
- @require_torch_gpu
+ @require_torch_accelerator
def test_set_attention_slice_max(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
unet = self.get_unet_model()
unet.set_attention_slice("max")
@@ -1264,15 +1203,15 @@ def test_set_attention_slice_max(self):
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 5 * 10**9
- @require_torch_gpu
+ @require_torch_accelerator
def test_set_attention_slice_int(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
unet = self.get_unet_model()
unet.set_attention_slice(2)
@@ -1284,15 +1223,15 @@ def test_set_attention_slice_int(self):
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 5 * 10**9
- @require_torch_gpu
+ @require_torch_accelerator
def test_set_attention_slice_list(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
# there are 32 sliceable layers
slice_list = 16 * [2, 3]
@@ -1306,7 +1245,7 @@ def test_set_attention_slice_list(self):
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 5 * 10**9
diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py
index 6f3662e01750..9431e810280f 100644
--- a/tests/models/unets/test_models_unet_controlnetxs.py
+++ b/tests/models/unets/test_models_unet_controlnetxs.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
import unittest
import numpy as np
@@ -269,37 +268,14 @@ def assert_unfrozen(module):
assert_unfrozen(u.ctrl_to_base)
def test_gradient_checkpointing_is_applied(self):
- model_class_copy = copy.copy(UNetControlNetXSModel)
-
- modules_with_gc_enabled = {}
-
- # now monkey patch the following function:
- # def _set_gradient_checkpointing(self, module, value=False):
- # if hasattr(module, "gradient_checkpointing"):
- # module.gradient_checkpointing = value
-
- def _set_gradient_checkpointing_new(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
- modules_with_gc_enabled[module.__class__.__name__] = True
-
- model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
- init_dict, _ = self.prepare_init_args_and_inputs_for_common()
- model = model_class_copy(**init_dict)
-
- model.enable_gradient_checkpointing()
-
- EXPECTED_SET = {
+ expected_set = {
"Transformer2DModel",
"UNetMidBlock2DCrossAttn",
"ControlNetXSCrossAttnDownBlock2D",
"ControlNetXSCrossAttnMidBlock2D",
"ControlNetXSCrossAttnUpBlock2D",
}
-
- assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
- assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@is_flaky
def test_forward_no_control(self):
@@ -344,6 +320,7 @@ def test_time_embedding_mixing(self):
assert output.shape == output_mix_time.shape
+ @unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
pass
diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py
index ee05f0d93824..209806a5fe26 100644
--- a/tests/models/unets/test_models_unet_motion.py
+++ b/tests/models/unets/test_models_unet_motion.py
@@ -161,27 +161,7 @@ def test_xformers_enable_works(self):
), "xformers is not enabled"
def test_gradient_checkpointing_is_applied(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model_class_copy = copy.copy(self.model_class)
-
- modules_with_gc_enabled = {}
-
- # now monkey patch the following function:
- # def _set_gradient_checkpointing(self, module, value=False):
- # if hasattr(module, "gradient_checkpointing"):
- # module.gradient_checkpointing = value
-
- def _set_gradient_checkpointing_new(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
- modules_with_gc_enabled[module.__class__.__name__] = True
-
- model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
- model = model_class_copy(**init_dict)
- model.enable_gradient_checkpointing()
-
- EXPECTED_SET = {
+ expected_set = {
"CrossAttnUpBlockMotion",
"CrossAttnDownBlockMotion",
"UNetMidBlockCrossAttnMotion",
@@ -189,9 +169,7 @@ def _set_gradient_checkpointing_new(self, module, value=False):
"Transformer2DModel",
"DownBlockMotion",
}
-
- assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
- assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_feed_forward_chunking(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/unets/test_models_unet_spatiotemporal.py b/tests/models/unets/test_models_unet_spatiotemporal.py
index afdd3d127702..0d7dc823b026 100644
--- a/tests/models/unets/test_models_unet_spatiotemporal.py
+++ b/tests/models/unets/test_models_unet_spatiotemporal.py
@@ -25,7 +25,6 @@
enable_full_determinism,
floats_tensor,
skip_mps,
- torch_all_close,
torch_device,
)
@@ -160,47 +159,6 @@ def test_xformers_enable_works(self):
== "XFormersAttnProcessor"
), "xformers is not enabled"
- @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
- def test_gradient_checkpointing(self):
- # enable deterministic behavior for gradient checkpointing
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
- model.to(torch_device)
-
- assert not model.is_gradient_checkpointing and model.training
-
- out = model(**inputs_dict).sample
- # run the backwards pass on the model. For backwards pass, for simplicity purpose,
- # we won't calculate the loss and rather backprop on out.sum()
- model.zero_grad()
-
- labels = torch.randn_like(out)
- loss = (out - labels).mean()
- loss.backward()
-
- # re-instantiate the model now enabling gradient checkpointing
- model_2 = self.model_class(**init_dict)
- # clone model
- model_2.load_state_dict(model.state_dict())
- model_2.to(torch_device)
- model_2.enable_gradient_checkpointing()
-
- assert model_2.is_gradient_checkpointing and model_2.training
-
- out_2 = model_2(**inputs_dict).sample
- # run the backwards pass on the model. For backwards pass, for simplicity purpose,
- # we won't calculate the loss and rather backprop on out.sum()
- model_2.zero_grad()
- loss_2 = (out_2 - labels).mean()
- loss_2.backward()
-
- # compare the output and parameters gradients
- self.assertTrue((loss - loss_2).abs() < 1e-5)
- named_params = dict(model.named_parameters())
- named_params_2 = dict(model_2.named_parameters())
- for name, param in named_params.items():
- self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
def test_model_with_num_attention_heads_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -239,30 +197,7 @@ def test_model_with_cross_attention_dim_tuple(self):
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_gradient_checkpointing_is_applied(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- init_dict["num_attention_heads"] = (8, 16)
-
- model_class_copy = copy.copy(self.model_class)
-
- modules_with_gc_enabled = {}
-
- # now monkey patch the following function:
- # def _set_gradient_checkpointing(self, module, value=False):
- # if hasattr(module, "gradient_checkpointing"):
- # module.gradient_checkpointing = value
-
- def _set_gradient_checkpointing_new(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
- modules_with_gc_enabled[module.__class__.__name__] = True
-
- model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
- model = model_class_copy(**init_dict)
- model.enable_gradient_checkpointing()
-
- EXPECTED_SET = {
+ expected_set = {
"TransformerSpatioTemporalModel",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
@@ -270,9 +205,10 @@ def _set_gradient_checkpointing_new(self, module, value=False):
"CrossAttnUpBlockSpatioTemporal",
"UNetMidBlockSpatioTemporal",
}
-
- assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
- assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+ num_attention_heads = (8, 16)
+ super().test_gradient_checkpointing_is_applied(
+ expected_set=expected_set, num_attention_heads=num_attention_heads
+ )
def test_pickle(self):
# enable deterministic behavior for gradient checkpointing
diff --git a/tests/others/test_check_support_list.py b/tests/others/test_check_support_list.py
new file mode 100644
index 000000000000..0f6b134aad49
--- /dev/null
+++ b/tests/others/test_check_support_list.py
@@ -0,0 +1,68 @@
+import os
+import sys
+import unittest
+from unittest.mock import mock_open, patch
+
+
+git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+sys.path.append(os.path.join(git_repo_path, "utils"))
+
+from check_support_list import check_documentation # noqa: E402
+
+
+class TestCheckSupportList(unittest.TestCase):
+ def setUp(self):
+ # Mock doc and source contents that we can reuse
+ self.doc_content = """# Documentation
+## FooProcessor
+
+[[autodoc]] module.FooProcessor
+
+## BarProcessor
+
+[[autodoc]] module.BarProcessor
+"""
+ self.source_content = """
+class FooProcessor(nn.Module):
+ pass
+
+class BarProcessor(nn.Module):
+ pass
+"""
+
+ def test_check_documentation_all_documented(self):
+ # In this test, both FooProcessor and BarProcessor are documented
+ with patch("builtins.open", mock_open(read_data=self.doc_content)) as doc_file:
+ doc_file.side_effect = [
+ mock_open(read_data=self.doc_content).return_value,
+ mock_open(read_data=self.source_content).return_value,
+ ]
+
+ undocumented = check_documentation(
+ doc_path="fake_doc.md",
+ src_path="fake_source.py",
+ doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
+ src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
+ )
+ self.assertEqual(len(undocumented), 0, f"Expected no undocumented classes, got {undocumented}")
+
+ def test_check_documentation_missing_class(self):
+ # In this test, only FooProcessor is documented, but BarProcessor is missing from the docs
+ doc_content_missing = """# Documentation
+## FooProcessor
+
+[[autodoc]] module.FooProcessor
+"""
+ with patch("builtins.open", mock_open(read_data=doc_content_missing)) as doc_file:
+ doc_file.side_effect = [
+ mock_open(read_data=doc_content_missing).return_value,
+ mock_open(read_data=self.source_content).return_value,
+ ]
+
+ undocumented = check_documentation(
+ doc_path="fake_doc.md",
+ src_path="fake_source.py",
+ doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
+ src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
+ )
+ self.assertIn("BarProcessor", undocumented, f"BarProcessor should be undocumented, got {undocumented}")
diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py
index 5bed42b8488f..7cf8f30ecc44 100644
--- a/tests/others/test_ema.py
+++ b/tests/others/test_ema.py
@@ -59,6 +59,26 @@ def simulate_backprop(self, unet):
unet.load_state_dict(updated_state_dict)
return unet
+ def test_from_pretrained(self):
+ # Save the model parameters to a temporary directory
+ unet, ema_unet = self.get_models()
+ with tempfile.TemporaryDirectory() as tmpdir:
+ ema_unet.save_pretrained(tmpdir)
+
+ # Load the EMA model from the saved directory
+ loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
+ loaded_ema_unet.to(torch_device)
+
+ # Check that the shadow parameters of the loaded model match the original EMA model
+ for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+ assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+ # Verify that the optimization step is also preserved
+ assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+ # Check the decay value
+ assert loaded_ema_unet.decay == ema_unet.decay
+
def test_optimization_steps_updated(self):
unet, ema_unet = self.get_models()
# Take the first (hypothetical) EMA step.
@@ -194,6 +214,26 @@ def simulate_backprop(self, unet):
unet.load_state_dict(updated_state_dict)
return unet
+ def test_from_pretrained(self):
+ # Save the model parameters to a temporary directory
+ unet, ema_unet = self.get_models()
+ with tempfile.TemporaryDirectory() as tmpdir:
+ ema_unet.save_pretrained(tmpdir)
+
+ # Load the EMA model from the saved directory
+ loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
+ loaded_ema_unet.to(torch_device)
+
+ # Check that the shadow parameters of the loaded model match the original EMA model
+ for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+ assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+ # Verify that the optimization step is also preserved
+ assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+ # Check the decay value
+ assert loaded_ema_unet.decay == ema_unet.decay
+
def test_optimization_steps_updated(self):
unet, ema_unet = self.get_models()
# Take the first (hypothetical) EMA step.
diff --git a/tests/pipelines/hunyuan_dit/__init__.py b/tests/pipelines/allegro/__init__.py
similarity index 100%
rename from tests/pipelines/hunyuan_dit/__init__.py
rename to tests/pipelines/allegro/__init__.py
diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py
new file mode 100644
index 000000000000..30fdd68cfd36
--- /dev/null
+++ b/tests/pipelines/allegro/test_allegro.py
@@ -0,0 +1,372 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import inspect
+import os
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5Config, T5EncoderModel
+
+from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ numpy_cosine_similarity_distance,
+ require_hf_hub_version_greater,
+ require_torch_accelerator,
+ require_transformers_version_greater,
+ slow,
+ torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class AllegroPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase):
+ pipeline_class = AllegroPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self, num_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = AllegroTransformer3DModel(
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=4,
+ out_channels=4,
+ num_layers=num_layers,
+ cross_attention_dim=24,
+ sample_width=8,
+ sample_height=8,
+ sample_frames=8,
+ caption_channels=24,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLAllegro(
+ in_channels=3,
+ out_channels=3,
+ down_block_types=(
+ "AllegroDownBlock3D",
+ "AllegroDownBlock3D",
+ "AllegroDownBlock3D",
+ "AllegroDownBlock3D",
+ ),
+ up_block_types=(
+ "AllegroUpBlock3D",
+ "AllegroUpBlock3D",
+ "AllegroUpBlock3D",
+ "AllegroUpBlock3D",
+ ),
+ block_out_channels=(8, 8, 8, 8),
+ latent_channels=4,
+ layers_per_block=1,
+ norm_num_groups=2,
+ temporal_compression_ratio=4,
+ )
+
+ # TODO(aryan): Only for now, since VAE decoding without tiling is not yet implemented here
+ vae.enable_tiling()
+
+ torch.manual_seed(0)
+ scheduler = DDIMScheduler()
+
+ text_encoder_config = T5Config(
+ **{
+ "d_ff": 37,
+ "d_kv": 8,
+ "d_model": 24,
+ "num_decoder_layers": 2,
+ "num_heads": 4,
+ "num_layers": 2,
+ "relative_attention_num_buckets": 8,
+ "vocab_size": 1103,
+ }
+ )
+ text_encoder = T5EncoderModel(text_encoder_config)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 8,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ @unittest.skip("Decoding without tiling is not yet implemented")
+ def test_save_load_local(self):
+ pass
+
+ @unittest.skip("Decoding without tiling is not yet implemented")
+ def test_save_load_optional_components(self):
+ pass
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (8, 3, 16, 16))
+ expected_video = torch.randn(8, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ # TODO(aryan)
+ @unittest.skip("Decoding without tiling is not yet implemented.")
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_overlap_factor_height=1 / 12,
+ tile_overlap_factor_width=1 / 12,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ @require_hf_hub_version_greater("0.26.5")
+ @require_transformers_version_greater("4.47.1")
+ def test_save_load_dduf(self):
+ # reimplement because it needs `enable_tiling()` on the loaded pipe.
+ from huggingface_hub import export_folder_as_dduf
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device="cpu")
+ inputs.pop("generator")
+ inputs["generator"] = torch.manual_seed(0)
+
+ pipeline_out = pipe(**inputs)[0].cpu()
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf")
+ pipe.save_pretrained(tmpdir, safe_serialization=True)
+ export_folder_as_dduf(dduf_filename, folder_path=tmpdir)
+ loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device)
+
+ loaded_pipe.vae.enable_tiling()
+ inputs["generator"] = torch.manual_seed(0)
+ loaded_pipeline_out = loaded_pipe(**inputs)[0].cpu()
+
+ assert np.allclose(pipeline_out, loaded_pipeline_out)
+
+
+@slow
+@require_torch_accelerator
+class AllegroPipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_allegro(self):
+ generator = torch.Generator("cpu").manual_seed(0)
+
+ pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16)
+ pipe.enable_model_cpu_offload(device=torch_device)
+ prompt = self.prompt
+
+ videos = pipe(
+ prompt=prompt,
+ height=720,
+ width=1280,
+ num_frames=88,
+ generator=generator,
+ num_inference_steps=2,
+ output_type="pt",
+ ).frames
+
+ video = videos[0]
+ expected_video = torch.randn(1, 88, 720, 1280, 3).numpy()
+
+ max_diff = numpy_cosine_similarity_distance(video, expected_video)
+ assert max_diff < 1e-3, f"Max diff is too high. got {video}"
diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py
index 32f3e13ad911..a0fbc5df1c28 100644
--- a/tests/pipelines/amused/test_amused.py
+++ b/tests/pipelines/amused/test_amused.py
@@ -22,7 +22,7 @@
from diffusers import AmusedPipeline, AmusedScheduler, UVit2DModel, VQModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -38,6 +38,8 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = AmusedPipeline
params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
@@ -129,7 +131,7 @@ def test_inference_batch_single_identical(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class AmusedPipelineSlowTests(unittest.TestCase):
def test_amused_256(self):
pipe = AmusedPipeline.from_pretrained("amused/amused-256")
diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py
index c647a5aa304e..2699bbe7f56f 100644
--- a/tests/pipelines/amused/test_amused_img2img.py
+++ b/tests/pipelines/amused/test_amused_img2img.py
@@ -23,7 +23,7 @@
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -131,7 +131,7 @@ def test_inference_batch_single_identical(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class AmusedImg2ImgPipelineSlowTests(unittest.TestCase):
def test_amused_256(self):
pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-256")
diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py
index 4a8d501450bb..645379a7eab1 100644
--- a/tests/pipelines/amused/test_amused_inpaint.py
+++ b/tests/pipelines/amused/test_amused_inpaint.py
@@ -23,7 +23,7 @@
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -135,7 +135,7 @@ def test_inference_batch_single_identical(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class AmusedInpaintPipelineSlowTests(unittest.TestCase):
def test_amused_256(self):
pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-256")
diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index 54c83d6a1b68..4088d46df5b2 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -19,7 +19,14 @@
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ numpy_cosine_similarity_distance,
+ require_accelerator,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
@@ -53,6 +60,8 @@ class AnimateDiffPipelineFastTests(
"callback_on_step_end_tensor_inputs",
]
)
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
cross_attention_dim = 8
@@ -272,7 +281,7 @@ def test_inference_batch_single_identical(
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -288,14 +297,14 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
- self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+ output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
+ self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
@@ -539,21 +548,29 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_vae_slicing(self):
return super().test_vae_slicing(image_count=2)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "num_images_per_prompt": 1,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class AnimateDiffPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_animatediff(self):
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
@@ -567,7 +584,7 @@ def test_animatediff(self):
clip_sample=False,
)
pipe.enable_vae_slicing()
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py
index 519d848c6dc2..7bde663b111e 100644
--- a/tests/pipelines/animatediff/test_animatediff_controlnet.py
+++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py
@@ -21,7 +21,7 @@
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import torch_device
+from diffusers.utils.testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
@@ -281,7 +281,7 @@ def test_inference_batch_single_identical(
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -297,14 +297,14 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
- self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+ output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
+ self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
@@ -517,3 +517,11 @@ def test_vae_slicing(self, video_count=2):
output_2 = pipe(**inputs)
assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
+
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "num_images_per_prompt": 1,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
diff --git a/tests/pipelines/animatediff/test_animatediff_sdxl.py b/tests/pipelines/animatediff/test_animatediff_sdxl.py
index 2db0139154e9..f9686ec005f7 100644
--- a/tests/pipelines/animatediff/test_animatediff_sdxl.py
+++ b/tests/pipelines/animatediff/test_animatediff_sdxl.py
@@ -14,14 +14,13 @@
UNetMotionModel,
)
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import torch_device
+from diffusers.utils.testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
PipelineTesterMixin,
SDFunctionTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -36,7 +35,6 @@ class AnimateDiffPipelineSDXLFastTests(
IPAdapterTesterMixin,
SDFunctionTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = AnimateDiffSDXLPipeline
@@ -212,7 +210,7 @@ def test_inference_batch_single_identical(
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -228,14 +226,14 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
- self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+ output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
+ self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
@@ -250,33 +248,6 @@ def test_to_dtype(self):
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
- def test_prompt_embeds(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = pipe.encode_prompt(prompt)
-
- pipe(
- **inputs,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- )
-
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -305,3 +276,11 @@ def test_xformers_attention_forwardGenerator_pass(self):
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
+
+ @unittest.skip("Test currently not supported.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ @unittest.skip("Functionality is tested elsewhere.")
+ def test_save_load_optional_components(self):
+ pass
diff --git a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
index 189d6765de4f..3e33326c8a87 100644
--- a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
+++ b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
@@ -20,7 +20,7 @@
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import torch_device
+from diffusers.utils.testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
@@ -345,7 +345,7 @@ def test_inference_batch_single_identical_use_simplified_condition_embedding_tru
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -361,13 +361,13 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+ output_cuda = pipe(**self.get_dummy_inputs(torch_device))[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
def test_to_dtype(self):
@@ -484,3 +484,11 @@ def test_free_init_with_schedulers(self):
def test_vae_slicing(self):
return super().test_vae_slicing(image_count=2)
+
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "num_images_per_prompt": 1,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py
index c3fd4c73736a..bc771e148eb2 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video.py
@@ -19,7 +19,7 @@
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import torch_device
+from diffusers.utils.testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
@@ -258,7 +258,7 @@ def test_inference_batch_single_identical(
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -274,14 +274,14 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
- self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+ output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
+ self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
@@ -544,3 +544,11 @@ def test_free_noise_multi_prompt(self):
inputs["strength"] = 0.5
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
pipe(**inputs).frames[0]
+
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "num_images_per_prompt": 1,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py
index 5e598e67ec11..3babbbe4ba11 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py
@@ -20,7 +20,7 @@
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import torch_device
+from diffusers.utils.testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
@@ -274,7 +274,7 @@ def test_inference_batch_single_identical(
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -290,13 +290,13 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+ output_cuda = pipe(**self.get_dummy_inputs(torch_device))[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
def test_to_dtype(self):
@@ -533,3 +533,11 @@ def test_free_noise_multi_prompt(self):
inputs["strength"] = 0.5
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
pipe(**inputs).frames[0]
+
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "num_images_per_prompt": 1,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
diff --git a/tests/pipelines/audioldm/test_audioldm.py b/tests/pipelines/audioldm/test_audioldm.py
index eddab54a3c03..aaf44985aafd 100644
--- a/tests/pipelines/audioldm/test_audioldm.py
+++ b/tests/pipelines/audioldm/test_audioldm.py
@@ -63,6 +63,8 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
)
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py
index fb550dd3219d..66052392f07f 100644
--- a/tests/pipelines/audioldm2/test_audioldm2.py
+++ b/tests/pipelines/audioldm2/test_audioldm2.py
@@ -70,6 +70,8 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
)
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = AudioLDM2UNet2DConditionModel(
@@ -469,8 +471,8 @@ def test_xformers_attention_forwardGenerator_pass(self):
pass
def test_dict_tuple_outputs_equivalent(self):
- # increase tolerance from 1e-4 -> 2e-4 to account for large composite model
- super().test_dict_tuple_outputs_equivalent(expected_max_difference=2e-4)
+ # increase tolerance from 1e-4 -> 3e-4 to account for large composite model
+ super().test_dict_tuple_outputs_equivalent(expected_max_difference=3e-4)
def test_inference_batch_single_identical(self):
# increase tolerance from 1e-4 -> 2e-4 to account for large composite model
@@ -506,9 +508,14 @@ def test_to_dtype(self):
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
+ @unittest.skip("Test not supported.")
def test_sequential_cpu_offload_forward_pass(self):
pass
+ @unittest.skip("Test not supported for now because of the use of `projection_model` in `encode_prompt()`.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
@nightly
class AudioLDM2PipelineSlowTests(unittest.TestCase):
diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
index 14bc588df905..c56aeb905ac3 100644
--- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
+++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
@@ -5,9 +5,6 @@
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import (
- torch_device,
-)
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -30,6 +27,8 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
@@ -88,37 +87,6 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs
- def test_aura_flow_prompt_embeds(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
- inputs = self.get_dummy_inputs(torch_device)
-
- output_with_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- do_classifier_free_guidance = inputs["guidance_scale"] > 1
- (
- prompt_embeds,
- prompt_attention_mask,
- negative_prompt_embeds,
- negative_prompt_attention_mask,
- ) = pipe.encode_prompt(
- prompt,
- do_classifier_free_guidance=do_classifier_free_guidance,
- device=torch_device,
- )
- output_with_embeds = pipe(
- prompt_embeds=prompt_embeds,
- prompt_attention_mask=prompt_attention_mask,
- negative_prompt_embeds=negative_prompt_embeds,
- negative_prompt_attention_mask=negative_prompt_attention_mask,
- **inputs,
- ).images[0]
-
- max_diff = np.abs(output_with_prompt - output_with_embeds).max()
- assert max_diff < 1e-4
-
def test_attention_slicing_forward_pass(self):
# Attention slicing needs to implemented differently for this because how single DiT and MMDiT
# blocks interfere with each other.
diff --git a/tests/pipelines/blipdiffusion/test_blipdiffusion.py b/tests/pipelines/blipdiffusion/test_blipdiffusion.py
index 7e85cef65129..e073f55aec9e 100644
--- a/tests/pipelines/blipdiffusion/test_blipdiffusion.py
+++ b/tests/pipelines/blipdiffusion/test_blipdiffusion.py
@@ -60,6 +60,8 @@ class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"prompt_reps",
]
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
@@ -196,3 +198,7 @@ def test_blipdiffusion(self):
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
+
+ @unittest.skip("Test not supported because of complexities in deriving query_embeds.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py
index 884ddfb2a95a..388dc9ef7ec4 100644
--- a/tests/pipelines/cogvideo/test_cogvideox.py
+++ b/tests/pipelines/cogvideo/test_cogvideox.py
@@ -24,14 +24,16 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
+ FasterCacheTesterMixin,
PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
to_np,
@@ -41,7 +43,9 @@
enable_full_determinism()
-class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class CogVideoXPipelineFastTests(
+ PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -58,8 +62,10 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
)
test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
- def get_dummy_components(self):
+ def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = CogVideoXTransformer3DModel(
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
@@ -71,7 +77,7 @@ def get_dummy_components(self):
out_channels=4,
time_embed_dim=2,
text_embed_dim=32, # Must match with tiny-random-t5
- num_layers=1,
+ num_layers=num_layers,
sample_width=2, # latent width: 2 -> final width: 16
sample_height=2, # latent height: 2 -> final height: 16
sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
@@ -321,7 +327,7 @@ def test_fused_qkv_projections(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class CogVideoXPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
@@ -339,7 +345,7 @@ def test_cogvideox(self):
generator = torch.Generator("cpu").manual_seed(0)
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
videos = pipe(
diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
index 2a51fc65798c..2e962bd247b9 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
@@ -55,6 +55,8 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas
]
)
test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
index f7e1fe7fd6c7..cac47f1a83d4 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
@@ -24,9 +24,10 @@
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -344,25 +345,25 @@ def test_fused_qkv_projections(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class CogVideoXImageToVideoPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_cogvideox(self):
generator = torch.Generator("cpu").manual_seed(0)
pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
image = load_image(
diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py
index 8d56552ba5ee..79dffd230a75 100644
--- a/tests/pipelines/cogview3/test_cogview3plus.py
+++ b/tests/pipelines/cogview3/test_cogview3plus.py
@@ -24,7 +24,7 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -56,6 +56,8 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
)
test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
@@ -230,9 +232,12 @@ def test_attention_slicing_forward_pass(
"Attention slicing should not affect the inference results",
)
+ def test_encode_prompt_works_in_isolation(self):
+ return super().test_encode_prompt_works_in_isolation(atol=1e-3, rtol=1e-3)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
@@ -250,7 +255,7 @@ def test_cogview3plus(self):
generator = torch.Generator("cpu").manual_seed(0)
pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b", torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
images = pipe(
diff --git a/tests/pipelines/cogview4/__init__.py b/tests/pipelines/cogview4/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/cogview4/test_cogview4.py b/tests/pipelines/cogview4/test_cogview4.py
new file mode 100644
index 000000000000..2a97a0799d76
--- /dev/null
+++ b/tests/pipelines/cogview4/test_cogview4.py
@@ -0,0 +1,234 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, GlmConfig, GlmForCausalLM
+
+from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class CogView4PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = CogView4Pipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = CogView4Transformer2DModel(
+ patch_size=2,
+ in_channels=4,
+ num_layers=2,
+ attention_head_dim=4,
+ num_attention_heads=4,
+ out_channels=4,
+ text_embed_dim=32,
+ time_embed_dim=8,
+ condition_dim=4,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=128,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(
+ base_shift=0.25,
+ max_shift=0.75,
+ base_image_seq_len=256,
+ use_dynamic_shifting=True,
+ time_shift_type="linear",
+ )
+
+ torch.manual_seed(0)
+ text_encoder_config = GlmConfig(
+ hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
+ )
+ text_encoder = GlmForCausalLM(text_encoder_config)
+ # TODO(aryan): change this to THUDM/CogView4 once released
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 16, 16))
+ expected_image = torch.randn(3, 16, 16)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
diff --git a/tests/pipelines/consisid/__init__.py b/tests/pipelines/consisid/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/consisid/test_consisid.py b/tests/pipelines/consisid/test_consisid.py
new file mode 100644
index 000000000000..a39c17bb4f79
--- /dev/null
+++ b/tests/pipelines/consisid/test_consisid.py
@@ -0,0 +1,361 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLCogVideoX, ConsisIDPipeline, ConsisIDTransformer3DModel, DDIMScheduler
+from diffusers.utils import load_image
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ numpy_cosine_similarity_distance,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ to_np,
+)
+
+
+enable_full_determinism()
+
+
+class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = ConsisIDPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = ConsisIDTransformer3DModel(
+ num_attention_heads=2,
+ attention_head_dim=16,
+ in_channels=8,
+ out_channels=4,
+ time_embed_dim=2,
+ text_embed_dim=32,
+ num_layers=1,
+ sample_width=2,
+ sample_height=2,
+ sample_frames=9,
+ patch_size=2,
+ temporal_compression_ratio=4,
+ max_text_seq_length=16,
+ use_rotary_positional_embeddings=True,
+ use_learned_positional_embeddings=True,
+ cross_attn_interval=1,
+ is_kps=False,
+ is_train_face=True,
+ cross_attn_dim_head=1,
+ cross_attn_num_heads=1,
+ LFE_id_dim=2,
+ LFE_vit_dim=2,
+ LFE_depth=5,
+ LFE_dim_head=8,
+ LFE_num_heads=2,
+ LFE_num_id_token=1,
+ LFE_num_querie=1,
+ LFE_output_dim=21,
+ LFE_ff_mult=1,
+ LFE_num_scale=1,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLCogVideoX(
+ in_channels=3,
+ out_channels=3,
+ down_block_types=(
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ ),
+ up_block_types=(
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ ),
+ block_out_channels=(8, 8, 8, 8),
+ latent_channels=4,
+ layers_per_block=1,
+ norm_num_groups=2,
+ temporal_compression_ratio=4,
+ )
+
+ torch.manual_seed(0)
+ scheduler = DDIMScheduler()
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ id_vit_hidden = [torch.ones([1, 2, 2])] * 1
+ id_cond = torch.ones(1, 2)
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": image_height,
+ "width": image_width,
+ "num_frames": 8,
+ "max_sequence_length": 16,
+ "id_vit_hidden": id_vit_hidden,
+ "id_cond": id_cond,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (8, 3, 16, 16))
+ expected_video = torch.randn(8, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.4):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ # The reason to modify it this way is because ConsisID Transformer limits the generation to resolutions used during initalization.
+ # This limitation comes from using learned positional embeddings which cannot be generated on-the-fly like sincos or RoPE embeddings.
+ # See the if-statement on "self.use_learned_positional_embeddings" in diffusers/models/embeddings.py
+ components["transformer"] = ConsisIDTransformer3DModel.from_config(
+ components["transformer"].config,
+ sample_height=16,
+ sample_width=16,
+ )
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_overlap_factor_height=1 / 12,
+ tile_overlap_factor_width=1 / 12,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+
+@slow
+@require_torch_gpu
+class ConsisIDPipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_consisid(self):
+ generator = torch.Generator("cpu").manual_seed(0)
+
+ pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
+ pipe.enable_model_cpu_offload()
+
+ prompt = self.prompt
+ image = load_image("https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true")
+ id_vit_hidden = [torch.ones([1, 2, 2])] * 1
+ id_cond = torch.ones(1, 2)
+
+ videos = pipe(
+ image=image,
+ prompt=prompt,
+ height=480,
+ width=720,
+ num_frames=16,
+ id_vit_hidden=id_vit_hidden,
+ id_cond=id_cond,
+ generator=generator,
+ num_inference_steps=1,
+ output_type="pt",
+ ).frames
+
+ video = videos[0]
+ expected_video = torch.randn(1, 16, 480, 720, 3).numpy()
+
+ max_diff = numpy_cosine_similarity_distance(video, expected_video)
+ assert max_diff < 1e-3, f"Max diff is too high. got {video}"
diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py
index b12655d989d4..bb21c9ac8dcb 100644
--- a/tests/pipelines/controlnet/test_controlnet.py
+++ b/tests/pipelines/controlnet/test_controlnet.py
@@ -34,13 +34,17 @@
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
get_python_version,
is_torch_compile,
load_image,
load_numpy,
require_torch_2,
- require_torch_gpu,
+ require_torch_accelerator,
run_test_in_subprocess,
slow,
torch_device,
@@ -75,7 +79,7 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.to("cuda")
+ pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.unet.to(memory_format=torch.channels_last)
@@ -122,6 +126,8 @@ class ControlNetPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
@@ -282,6 +288,13 @@ def test_controlnet_lcm_custom_timesteps(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
class StableDiffusionMultiControlNetPipelineFastTests(
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -291,6 +304,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -514,6 +529,13 @@ def test_inference_multiple_prompt_input(self):
assert image.shape == (4, 64, 64, 3)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
class StableDiffusionMultiControlNetOneModelPipelineFastTests(
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -523,6 +545,8 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -697,19 +721,26 @@ def test_save_pretrained_raise_not_implemented_exception(self):
except NotImplementedError:
pass
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class ControlNetPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_canny(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
@@ -717,7 +748,7 @@ def test_canny(self):
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -744,7 +775,7 @@ def test_depth(self):
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -771,7 +802,7 @@ def test_hed(self):
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -798,7 +829,7 @@ def test_mlsd(self):
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -825,7 +856,7 @@ def test_normal(self):
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -852,7 +883,7 @@ def test_openpose(self):
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -879,7 +910,7 @@ def test_scribble(self):
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(5)
@@ -906,7 +937,7 @@ def test_seg(self):
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(5)
@@ -928,9 +959,9 @@ def test_seg(self):
assert np.abs(expected_image - image).max() < 8e-2
def test_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg")
@@ -939,7 +970,7 @@ def test_sequential_cpu_offloading(self):
)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
prompt = "house"
image = load_image(
@@ -953,7 +984,7 @@ def test_sequential_cpu_offloading(self):
output_type="np",
)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 7 GB is allocated
assert mem_bytes < 4 * 10**9
@@ -963,7 +994,7 @@ def test_canny_guess_mode(self):
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -996,7 +1027,7 @@ def test_canny_guess_mode_euler(self):
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -1037,7 +1068,7 @@ def test_v11_shuffle_global_pool_conditions(self):
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -1064,17 +1095,17 @@ def test_v11_shuffle_global_pool_conditions(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionMultiControlNetPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_pose_and_canny(self):
controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
@@ -1085,7 +1116,7 @@ def test_pose_and_canny(self):
safety_checker=None,
controlnet=[controlnet_pose, controlnet_canny],
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
index 99a238caf53a..eedda4e21722 100644
--- a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
+++ b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
@@ -68,6 +68,8 @@ class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.Tes
"prompt_reps",
]
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
@@ -220,3 +222,7 @@ def test_blipdiffusion_controlnet(self):
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+
+ @unittest.skip("Test not supported because of complexities in deriving query_embeds.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py
index 7c4ae716b37d..100765ee34cb 100644
--- a/tests/pipelines/controlnet/test_controlnet_img2img.py
+++ b/tests/pipelines/controlnet/test_controlnet_img2img.py
@@ -39,7 +39,7 @@
enable_full_determinism,
floats_tensor,
load_numpy,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -189,6 +189,13 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
class StableDiffusionMultiControlNetPipelineFastTests(
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -198,6 +205,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -389,9 +398,16 @@ def test_save_pretrained_raise_not_implemented_exception(self):
except NotImplementedError:
pass
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
@@ -409,7 +425,7 @@ def test_canny(self):
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py
index e49106334c2e..b06590e13cb6 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py
@@ -40,7 +40,7 @@
floats_tensor,
load_numpy,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -176,6 +176,13 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
class ControlNetSimpleInpaintPipelineFastTests(ControlNetInpaintPipelineFastTests):
pipeline_class = StableDiffusionControlNetInpaintPipeline
@@ -257,6 +264,8 @@ class MultiControlNetInpaintPipelineFastTests(
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -441,9 +450,16 @@ def test_save_pretrained_raise_not_implemented_exception(self):
except NotImplementedError:
pass
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
@@ -461,7 +477,7 @@ def test_canny(self):
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
"botp/stable-diffusion-v1-5-inpainting", safety_checker=None, controlnet=controlnet
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -507,7 +523,7 @@ def test_inpaint(self):
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(33)
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
index d2c63137c99e..ca05db504485 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
@@ -40,7 +40,7 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
- require_torch_gpu,
+ require_torch_accelerator,
torch_device,
)
@@ -78,6 +78,8 @@ class ControlNetPipelineSDXLFastTests(
}
)
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -243,7 +245,7 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
- @require_torch_gpu
+ @require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -252,12 +254,12 @@ def test_stable_diffusion_xl_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index c931391ac4d5..503db2f574e2 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -35,9 +35,10 @@
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -54,7 +55,6 @@
PipelineKarrasSchedulerTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -66,7 +66,6 @@ class StableDiffusionXLControlNetPipelineFastTests(
PipelineLatentTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLControlNetPipeline
@@ -74,6 +73,8 @@ class StableDiffusionXLControlNetPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
@@ -209,10 +210,11 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+ @unittest.skip("We test this functionality elsewhere already.")
def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
+ pass
- @require_torch_gpu
+ @require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -221,12 +223,12 @@ def test_stable_diffusion_xl_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -294,45 +296,6 @@ def test_stable_diffusion_xl_multi_prompts(self):
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
- # Copied from test_stable_diffusion_xl.py
- def test_stable_diffusion_xl_prompt_embeds(self):
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- # forward without prompt embeds
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 2 * [inputs["prompt"]]
- inputs["num_images_per_prompt"] = 2
-
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with prompt embeds
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 2 * [inputs.pop("prompt")]
-
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = sd_pipe.encode_prompt(prompt)
-
- output = sd_pipe(
- **inputs,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- )
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # make sure that it's equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
def test_controlnet_sdxl_guess(self):
device = "cpu"
@@ -480,13 +443,15 @@ def new_step(self, *args, **kwargs):
class StableDiffusionXLMultiControlNetPipelineFastTests(
- PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+ PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLControlNetPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -680,18 +645,21 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+ @unittest.skip("We test this functionality elsewhere already.")
def test_save_load_optional_components(self):
- return self._test_save_load_optional_components()
+ pass
class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
- PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+ PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLControlNetPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -855,6 +823,10 @@ def test_control_guidance_switch(self):
def test_attention_slicing_forward_pass(self):
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
+ @unittest.skip("We test this functionality elsewhere already.")
+ def test_save_load_optional_components(self):
+ pass
+
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -865,9 +837,6 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
def test_negative_conditions(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -889,17 +858,17 @@ def test_negative_conditions(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_canny(self):
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
@@ -907,7 +876,7 @@ def test_canny(self):
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -930,7 +899,7 @@ def test_depth(self):
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -1019,7 +988,7 @@ def test_conditioning_channels(self):
)
controlnet = ControlNetModel.from_unet(unet, conditioning_channels=4)
- assert type(controlnet.mid_block) == UNetMidBlock2D
+ assert type(controlnet.mid_block) is UNetMidBlock2D
assert controlnet.conditioning_channels == 4
def get_dummy_components(self, time_cond_proj_dim=None):
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
index 6a5976bd0dda..bf5da16fcbb8 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
@@ -28,7 +28,12 @@
UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ require_torch_accelerator,
+ torch_device,
+)
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
@@ -241,7 +246,7 @@ def test_inference_batch_single_identical(self):
def test_save_load_optional_components(self):
pass
- @require_torch_gpu
+ @require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -250,12 +255,12 @@ def test_stable_diffusion_xl_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -322,42 +327,3 @@ def test_stable_diffusion_xl_multi_prompts(self):
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
-
- # Copied from test_stable_diffusion_xl.py
- def test_stable_diffusion_xl_prompt_embeds(self):
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- # forward without prompt embeds
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 2 * [inputs["prompt"]]
- inputs["num_images_per_prompt"] = 2
-
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with prompt embeds
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 2 * [inputs.pop("prompt")]
-
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = sd_pipe.encode_prompt(prompt)
-
- output = sd_pipe(
- **inputs,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- )
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # make sure that it's equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
diff --git a/tests/pipelines/controlnet/test_flax_controlnet.py b/tests/pipelines/controlnet/test_flax_controlnet.py
index bf5564e810ef..c71116dc7927 100644
--- a/tests/pipelines/controlnet/test_flax_controlnet.py
+++ b/tests/pipelines/controlnet/test_flax_controlnet.py
@@ -78,7 +78,7 @@ def test_canny(self):
expected_slice = jnp.array(
[0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]
)
- print(f"output_slice: {output_slice}")
+
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
def test_pose(self):
@@ -123,5 +123,5 @@ def test_pose(self):
expected_slice = jnp.array(
[[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]
)
- print(f"output_slice: {output_slice}")
+
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
index d2db28bdda35..9a270c2bbf07 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -17,7 +17,9 @@
import unittest
import numpy as np
+import pytest
import torch
+from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from diffusers import (
@@ -29,24 +31,28 @@
from diffusers.models import FluxControlNetModel
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
- slow,
+ nightly,
+ numpy_cosine_similarity_distance,
+ require_big_gpu_with_torch_cuda,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin
enable_full_determinism()
-class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
pipeline_class = FluxControlNetPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
@@ -123,6 +129,8 @@ def get_dummy_components(self):
"transformer": transformer,
"vae": vae,
"controlnet": controlnet,
+ "image_encoder": None,
+ "feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
@@ -167,7 +175,7 @@ def test_controlnet_flux(self):
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array(
- [0.7348633, 0.41333008, 0.6621094, 0.5444336, 0.47607422, 0.5859375, 0.44677734, 0.4506836, 0.40454102]
+ [0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156]
)
assert (
@@ -178,56 +186,92 @@ def test_controlnet_flux(self):
def test_xformers_attention_forwardGenerator_pass(self):
pass
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
-@slow
-@require_torch_gpu
+ height_width_pairs = [(32, 32), (72, 56)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update(
+ {
+ "control_image": randn_tensor(
+ (1, 3, height, width),
+ device=torch_device,
+ dtype=torch.float16,
+ )
+ }
+ )
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+
+@nightly
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
class FluxControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetPipeline
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_canny(self):
controlnet = FluxControlNetModel.from_pretrained(
"InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
- "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
- )
- pipe.enable_model_cpu_offload()
+ "black-forest-labs/FLUX.1-dev",
+ text_encoder=None,
+ text_encoder_2=None,
+ controlnet=controlnet,
+ torch_dtype=torch.bfloat16,
+ ).to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "A girl in city, 25 years old, cool, futuristic"
control_image = load_image(
"https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg"
- )
+ ).resize((512, 512))
+
+ prompt_embeds = torch.load(
+ hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
+ ).to(torch_device)
+ pooled_prompt_embeds = torch.load(
+ hf_hub_download(
+ repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
+ )
+ ).to(torch_device)
output = pipe(
- prompt,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
control_image=control_image,
controlnet_conditioning_scale=0.6,
num_inference_steps=2,
guidance_scale=3.5,
+ max_sequence_length=256,
output_type="np",
+ height=512,
+ width=512,
generator=generator,
)
image = output.images[0]
- assert image.shape == (1024, 1024, 3)
+ assert image.shape == (512, 512, 3)
original_image = image[-3:, -3:, -1].flatten()
- expected_image = np.array(
- [0.33007812, 0.33984375, 0.33984375, 0.328125, 0.34179688, 0.33984375, 0.30859375, 0.3203125, 0.3203125]
- )
+ expected_image = np.array([0.2734, 0.2852, 0.2852, 0.2734, 0.2754, 0.2891, 0.2617, 0.2637, 0.2773])
- assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
+ assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
index 9c0e948861f7..59ccb9237819 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
@@ -1,4 +1,3 @@
-import gc
import unittest
import numpy as np
@@ -13,11 +12,9 @@
FluxTransformer2DModel,
)
from diffusers.utils.testing_utils import (
- numpy_cosine_similarity_distance,
- require_torch_gpu,
- slow,
torch_device,
)
+from diffusers.utils.torch_utils import randn_tensor
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -161,30 +158,6 @@ def test_flux_controlnet_different_prompts(self):
assert max_diff > 1e-6
- def test_flux_controlnet_prompt_embeds(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
- inputs = self.get_dummy_inputs(torch_device)
-
- output_with_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
- prompt,
- prompt_2=None,
- device=torch_device,
- max_sequence_length=inputs["max_sequence_length"],
- )
- output_with_embeds = pipe(
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- **inputs,
- ).images[0]
-
- max_diff = np.abs(output_with_prompt - output_with_embeds).max()
- assert max_diff < 1e-4
-
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -223,69 +196,30 @@ def test_fused_qkv_projections(self):
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
-@slow
-@require_torch_gpu
-class FluxControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
- pipeline_class = FluxControlNetImg2ImgPipeline
- repo_id = "black-forest-labs/FLUX.1-schnell"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def get_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device="cpu").manual_seed(seed)
-
- image = torch.randn(1, 3, 64, 64).to(device)
- control_image = torch.randn(1, 3, 64, 64).to(device)
-
- return {
- "prompt": "A photo of a cat",
- "image": image,
- "control_image": control_image,
- "num_inference_steps": 2,
- "guidance_scale": 5.0,
- "controlnet_conditioning_scale": 1.0,
- "strength": 0.8,
- "output_type": "np",
- "generator": generator,
- }
-
- @unittest.skip("We cannot run inference on this model with the current CI hardware")
- def test_flux_controlnet_img2img_inference(self):
- pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
- pipe.enable_model_cpu_offload()
-
- inputs = self.get_inputs(torch_device)
-
- image = pipe(**inputs).images[0]
- image_slice = image[0, :10, :10]
- expected_slice = np.array(
- [
- [0.36132812, 0.30004883, 0.25830078],
- [0.36669922, 0.31103516, 0.23754883],
- [0.34814453, 0.29248047, 0.23583984],
- [0.35791016, 0.30981445, 0.23999023],
- [0.36328125, 0.31274414, 0.2607422],
- [0.37304688, 0.32177734, 0.26171875],
- [0.3671875, 0.31933594, 0.25756836],
- [0.36035156, 0.31103516, 0.2578125],
- [0.3857422, 0.33789062, 0.27563477],
- [0.3701172, 0.31982422, 0.265625],
- ],
- dtype=np.float32,
- )
-
- max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
-
- assert max_diff < 1e-4
+ height_width_pairs = [(32, 32), (72, 56)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+ inputs.update(
+ {
+ "control_image": randn_tensor(
+ (1, 3, height, width),
+ device=torch_device,
+ dtype=torch.float16,
+ ),
+ "image": randn_tensor(
+ (1, 3, height, width),
+ device=torch_device,
+ dtype=torch.float16,
+ ),
+ "height": height,
+ "width": width,
+ }
+ )
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py
index d66eaaf6a76f..94d97e9962b7 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py
@@ -23,7 +23,9 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
+ torch_device,
)
+from diffusers.utils.torch_utils import randn_tensor
from ..test_pipelines_common import PipelineTesterMixin
@@ -192,3 +194,33 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 56)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update(
+ {
+ "control_image": randn_tensor(
+ (1, 3, height, width),
+ device=torch_device,
+ dtype=torch.float16,
+ ),
+ "image": randn_tensor(
+ (1, 3, height, width),
+ device=torch_device,
+ dtype=torch.float16,
+ ),
+ "mask_image": torch.ones((1, 1, height, width)).to(torch_device),
+ "height": height,
+ "width": width,
+ }
+ )
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
index 5500c7bd1c81..f7b3db05c8af 100644
--- a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
+++ b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
@@ -29,8 +29,9 @@
from diffusers.models import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -56,6 +57,7 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMix
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
+ test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
@@ -151,9 +153,14 @@ def test_controlnet_hunyuandit(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 16, 16, 3)
- expected_slice = np.array(
- [0.6953125, 0.89208984, 0.59375, 0.5078125, 0.5786133, 0.6035156, 0.5839844, 0.53564453, 0.52246094]
- )
+ if torch_device == "xpu":
+ expected_slice = np.array(
+ [0.6376953, 0.84375, 0.58691406, 0.48046875, 0.43652344, 0.5517578, 0.54248047, 0.5644531, 0.48217773]
+ )
+ else:
+ expected_slice = np.array(
+ [0.6953125, 0.89208984, 0.59375, 0.5078125, 0.5786133, 0.6035156, 0.5839844, 0.53564453, 0.52246094]
+ )
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -176,21 +183,27 @@ def test_save_load_optional_components(self):
# TODO(YiYi) need to fix later
pass
+ @unittest.skip(
+ "Test not supported as `encode_prompt` is called two times separately which deivates from about 99% of the pipelines we have."
+ )
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class HunyuanDiTControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = HunyuanDiTControlNetPipeline
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_canny(self):
controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
@@ -199,7 +212,7 @@ def test_canny(self):
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -238,7 +251,7 @@ def test_pose(self):
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -277,7 +290,7 @@ def test_depth(self):
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -318,7 +331,7 @@ def test_multi_controlnet(self):
pipe = HunyuanDiTControlNetPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -343,6 +356,7 @@ def test_multi_controlnet(self):
assert image.shape == (1024, 1024, 3)
original_image = image[-3:, -3:, -1].flatten()
+
expected_image = np.array(
[0.43652344, 0.44018555, 0.4494629, 0.44995117, 0.45654297, 0.44848633, 0.43603516, 0.4404297, 0.42626953]
)
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
index 9a2a0019d68b..2cd57ce56d52 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
@@ -137,6 +137,8 @@ def get_dummy_components(self):
"transformer": transformer,
"vae": vae,
"controlnet": controlnet,
+ "image_encoder": None,
+ "feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 74cb56e0337a..84ce09acbe1a 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -15,8 +15,10 @@
import gc
import unittest
+from typing import Optional
import numpy as np
+import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -29,8 +31,10 @@
from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ numpy_cosine_similarity_distance,
+ require_big_accelerator,
slow,
torch_device,
)
@@ -56,8 +60,12 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
+ test_layerwise_casting = True
+ test_group_offloading = True
- def get_dummy_components(self):
+ def get_dummy_components(
+ self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm", use_dual_attention=False
+ ):
torch.manual_seed(0)
transformer = SD3Transformer2DModel(
sample_size=32,
@@ -70,6 +78,8 @@ def get_dummy_components(self):
caption_projection_dim=32,
pooled_projection_dim=64,
out_channels=8,
+ qk_norm=qk_norm,
+ dual_attention_layers=() if not use_dual_attention else (0, 1),
)
torch.manual_seed(0)
@@ -77,14 +87,17 @@ def get_dummy_components(self):
sample_size=32,
patch_size=1,
in_channels=8,
- num_layers=1,
+ num_layers=num_controlnet_layers,
attention_head_dim=8,
num_attention_heads=4,
joint_attention_dim=32,
caption_projection_dim=32,
pooled_projection_dim=64,
out_channels=8,
+ qk_norm=qk_norm,
+ dual_attention_layers=() if not use_dual_attention else (0,),
)
+
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
@@ -140,6 +153,8 @@ def get_dummy_components(self):
"transformer": transformer,
"vae": vae,
"controlnet": controlnet,
+ "image_encoder": None,
+ "feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
@@ -169,8 +184,7 @@ def get_dummy_inputs(self, device, seed=0):
return inputs
- def test_controlnet_sd3(self):
- components = self.get_dummy_components()
+ def run_pipe(self, components, use_sd35=False):
sd_pipe = StableDiffusion3ControlNetPipeline(**components)
sd_pipe = sd_pipe.to(torch_device, dtype=torch.float16)
sd_pipe.set_progress_bar_config(disable=None)
@@ -183,38 +197,50 @@ def test_controlnet_sd3(self):
assert image.shape == (1, 32, 32, 3)
- expected_slice = np.array([0.5767, 0.7100, 0.5981, 0.5674, 0.5952, 0.4102, 0.5093, 0.5044, 0.6030])
+ if not use_sd35:
+ expected_slice = np.array([0.5767, 0.7100, 0.5981, 0.5674, 0.5952, 0.4102, 0.5093, 0.5044, 0.6030])
+ else:
+ expected_slice = np.array([1.0000, 0.9072, 0.4209, 0.2744, 0.5737, 0.3840, 0.6113, 0.6250, 0.6328])
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ def test_controlnet_sd3(self):
+ components = self.get_dummy_components()
+ self.run_pipe(components)
+
+ def test_controlnet_sd35(self):
+ components = self.get_dummy_components(num_controlnet_layers=1, qk_norm="rms_norm", use_dual_attention=True)
+ self.run_pipe(components, use_sd35=True)
+
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
def test_xformers_attention_forwardGenerator_pass(self):
pass
@slow
-@require_torch_gpu
+@require_big_accelerator
+@pytest.mark.big_gpu_with_torch_cuda
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3ControlNetPipeline
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_canny(self):
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -238,18 +264,16 @@ def test_canny(self):
original_image = image[-3:, -3:, -1].flatten()
- expected_image = np.array(
- [0.20947266, 0.1574707, 0.19897461, 0.15063477, 0.1418457, 0.17285156, 0.14160156, 0.13989258, 0.30810547]
- )
+ expected_image = np.array([0.7314, 0.7075, 0.6611, 0.7539, 0.7563, 0.6650, 0.6123, 0.7275, 0.7222])
- assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
+ assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2
def test_pose(self):
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Pose", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -272,19 +296,16 @@ def test_pose(self):
assert image.shape == (1024, 1024, 3)
original_image = image[-3:, -3:, -1].flatten()
+ expected_image = np.array([0.9048, 0.8740, 0.8936, 0.8516, 0.8799, 0.9360, 0.8379, 0.8408, 0.8652])
- expected_image = np.array(
- [0.8671875, 0.86621094, 0.91015625, 0.8491211, 0.87890625, 0.9140625, 0.8300781, 0.8334961, 0.8623047]
- )
-
- assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
+ assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2
def test_tile(self):
- controlnet = SD3ControlNetModel.from_pretrained("InstantX//SD3-Controlnet-Tile", torch_dtype=torch.float16)
+ controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Tile", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -307,12 +328,9 @@ def test_tile(self):
assert image.shape == (1024, 1024, 3)
original_image = image[-3:, -3:, -1].flatten()
+ expected_image = np.array([0.6699, 0.6836, 0.6226, 0.6572, 0.7310, 0.6646, 0.6650, 0.6694, 0.6011])
- expected_image = np.array(
- [0.6982422, 0.7011719, 0.65771484, 0.6904297, 0.7416992, 0.6904297, 0.6977539, 0.7080078, 0.6386719]
- )
-
- assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
+ assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2
def test_multi_controlnet(self):
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
@@ -321,7 +339,7 @@ def test_multi_controlnet(self):
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -344,8 +362,6 @@ def test_multi_controlnet(self):
assert image.shape == (1024, 1024, 3)
original_image = image[-3:, -3:, -1].flatten()
- expected_image = np.array(
- [0.7451172, 0.7416992, 0.7158203, 0.7792969, 0.7607422, 0.7089844, 0.6855469, 0.71777344, 0.7314453]
- )
+ expected_image = np.array([0.7207, 0.7041, 0.6543, 0.7500, 0.7490, 0.6592, 0.6001, 0.7168, 0.7231])
- assert np.abs(original_image.flatten() - expected_image).max() < 1e-2
+ assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py
index bb0306741fdb..74af4b6775cc 100644
--- a/tests/pipelines/controlnet_xs/test_controlnetxs.py
+++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py
@@ -34,19 +34,21 @@
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
is_torch_compile,
load_image,
load_numpy,
+ require_accelerator,
require_torch_2,
- require_torch_gpu,
+ require_torch_accelerator,
run_test_in_subprocess,
slow,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
-from ...models.autoencoders.test_models_vae import (
+from ...models.autoencoders.vae import (
get_asym_autoencoder_kl_config,
get_autoencoder_kl_config,
get_autoencoder_tiny_config,
@@ -91,7 +93,7 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
safety_checker=None,
torch_dtype=torch.float16,
)
- pipe.to("cuda")
+ pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.unet.to(memory_format=torch.channels_last)
@@ -137,6 +139,8 @@ class ControlNetXSPipelineFastTests(
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_attention_slicing = False
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
@@ -306,7 +310,7 @@ def test_multi_vae(self):
assert out_vae_np.shape == out_np.shape
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -322,23 +326,30 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
- self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+ output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
+ self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
+
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@slow
-@require_torch_gpu
+@require_torch_accelerator
class ControlNetXSPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_canny(self):
controlnet = ControlNetXSAdapter.from_pretrained(
@@ -347,7 +358,7 @@ def test_canny(self):
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -373,7 +384,7 @@ def test_depth(self):
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
index c940504d6c3e..24a8b9cd5739 100644
--- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
+++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
@@ -31,10 +31,17 @@
UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ load_image,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
from diffusers.utils.torch_utils import randn_tensor
-from ...models.autoencoders.test_models_vae import (
+from ...models.autoencoders.vae import (
get_asym_autoencoder_kl_config,
get_autoencoder_kl_config,
get_autoencoder_tiny_config,
@@ -50,7 +57,6 @@
PipelineKarrasSchedulerTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -61,7 +67,6 @@ class StableDiffusionXLControlNetXSPipelineFastTests(
PipelineLatentTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLControlNetXSPipeline
@@ -71,6 +76,8 @@ class StableDiffusionXLControlNetXSPipelineFastTests(
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_attention_slicing = False
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
@@ -192,7 +199,11 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
- @require_torch_gpu
+ @unittest.skip("We test this functionality elsewhere already.")
+ def test_save_load_optional_components(self):
+ pass
+
+ @require_torch_accelerator
# Copied from test_controlnet_sdxl.py
def test_stable_diffusion_xl_offloads(self):
pipes = []
@@ -202,12 +213,12 @@ def test_stable_diffusion_xl_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -276,49 +287,6 @@ def test_stable_diffusion_xl_multi_prompts(self):
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
- # Copied from test_stable_diffusion_xl.py
- def test_stable_diffusion_xl_prompt_embeds(self):
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- # forward without prompt embeds
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 2 * [inputs["prompt"]]
- inputs["num_images_per_prompt"] = 2
-
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with prompt embeds
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 2 * [inputs.pop("prompt")]
-
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = sd_pipe.encode_prompt(prompt)
-
- output = sd_pipe(
- **inputs,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- )
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # make sure that it's equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4
-
- # Copied from test_stable_diffusion_xl.py
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
# Copied from test_controlnetxs.py
def test_to_dtype(self):
components = self.get_dummy_components()
@@ -369,12 +337,12 @@ def test_multi_vae(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_canny(self):
controlnet = ControlNetXSAdapter.from_pretrained(
@@ -383,7 +351,7 @@ def test_canny(self):
pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -407,7 +375,7 @@ def test_depth(self):
pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py
index 2078a592ceca..f7e0093c515a 100644
--- a/tests/pipelines/ddim/test_ddim.py
+++ b/tests/pipelines/ddim/test_ddim.py
@@ -19,7 +19,7 @@
import torch
from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -99,7 +99,7 @@ def test_inference_batch_single_identical(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class DDIMPipelineIntegrationTests(unittest.TestCase):
def test_inference_cifar10(self):
model_id = "google/ddpm-cifar10-32"
diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py
index f6d0821da4c2..750885db2c23 100644
--- a/tests/pipelines/ddpm/test_ddpm.py
+++ b/tests/pipelines/ddpm/test_ddpm.py
@@ -19,7 +19,7 @@
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
enable_full_determinism()
@@ -88,7 +88,7 @@ def test_inference_predict_sample(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class DDPMPipelineIntegrationTests(unittest.TestCase):
def test_inference_cifar10(self):
model_id = "google/ddpm-cifar10-32"
diff --git a/tests/pipelines/deepfloyd_if/test_if.py b/tests/pipelines/deepfloyd_if/test_if.py
index 0818665ea113..295b29f12e8c 100644
--- a/tests/pipelines/deepfloyd_if/test_if.py
+++ b/tests/pipelines/deepfloyd_if/test_if.py
@@ -23,7 +23,19 @@
)
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import load_numpy, require_torch_gpu, skip_mps, slow, torch_device
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
+ load_numpy,
+ require_accelerator,
+ require_hf_hub_version_greater,
+ require_torch_accelerator,
+ require_transformers_version_greater,
+ skip_mps,
+ slow,
+ torch_device,
+)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -55,10 +67,8 @@ def get_dummy_inputs(self, device, seed=0):
return inputs
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_save_load_float16(self):
# Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
super().test_save_load_float16(expected_max_diff=1e-1)
@@ -81,30 +91,39 @@ def test_inference_batch_single_identical(self):
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+ @require_hf_hub_version_greater("0.26.5")
+ @require_transformers_version_greater("4.47.1")
+ def test_save_load_dduf(self):
+ super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
+ @unittest.skip("Functionality is tested elsewhere.")
+ def test_save_load_optional_components(self):
+ pass
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class IFPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_if_text_to_image(self):
pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.empty_cache()
- torch.cuda.reset_peak_memory_stats()
+ backend_reset_max_memory_allocated(torch_device)
+ backend_empty_cache(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(
diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img.py b/tests/pipelines/deepfloyd_if/test_if_img2img.py
index b71cb05e50ae..da06dc355896 100644
--- a/tests/pipelines/deepfloyd_if/test_if_img2img.py
+++ b/tests/pipelines/deepfloyd_if/test_if_img2img.py
@@ -22,7 +22,20 @@
from diffusers import IFImg2ImgPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
+ floats_tensor,
+ load_numpy,
+ require_accelerator,
+ require_hf_hub_version_greater,
+ require_torch_accelerator,
+ require_transformers_version_greater,
+ skip_mps,
+ slow,
+ torch_device,
+)
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -60,9 +73,6 @@ def get_dummy_inputs(self, device, seed=0):
return inputs
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -70,12 +80,14 @@ def test_save_load_optional_components(self):
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_save_load_float16(self):
# Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
super().test_save_load_float16(expected_max_diff=1e-1)
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
@@ -90,21 +102,30 @@ def test_inference_batch_single_identical(self):
expected_max_diff=1e-2,
)
+ @require_hf_hub_version_greater("0.26.5")
+ @require_transformers_version_greater("4.47.1")
+ def test_save_load_dduf(self):
+ super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
+ @unittest.skip("Functionality is tested elsewhere.")
+ def test_save_load_optional_components(self):
+ pass
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class IFImg2ImgPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_if_img2img(self):
pipe = IFImg2ImgPipeline.from_pretrained(
@@ -113,11 +134,11 @@ def test_if_img2img(self):
torch_dtype=torch.float16,
)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.empty_cache()
- torch.cuda.reset_peak_memory_stats()
+ backend_reset_max_memory_allocated(torch_device)
+ backend_empty_cache(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
index dc0cf9826b62..77f2f9c7bb64 100644
--- a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
@@ -22,7 +22,21 @@
from diffusers import IFImg2ImgSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
+ floats_tensor,
+ load_numpy,
+ require_accelerator,
+ require_hf_hub_version_greater,
+ require_torch_accelerator,
+ require_transformers_version_greater,
+ skip_mps,
+ slow,
+ torch_device,
+)
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -69,10 +83,8 @@ def get_dummy_inputs(self, device, seed=0):
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_save_load_float16(self):
# Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
super().test_save_load_float16(expected_max_diff=1e-1)
@@ -88,21 +100,30 @@ def test_inference_batch_single_identical(self):
expected_max_diff=1e-2,
)
+ @require_hf_hub_version_greater("0.26.5")
+ @require_transformers_version_greater("4.47.1")
+ def test_save_load_dduf(self):
+ super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
+ @unittest.skip("Functionality is tested elsewhere.")
+ def test_save_load_optional_components(self):
+ pass
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class IFImg2ImgSuperResolutionPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_if_img2img_superresolution(self):
pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained(
@@ -111,11 +132,11 @@ def test_if_img2img_superresolution(self):
torch_dtype=torch.float16,
)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.empty_cache()
- torch.cuda.reset_peak_memory_stats()
+ backend_reset_max_memory_allocated(torch_device)
+ backend_empty_cache(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -135,7 +156,8 @@ def test_if_img2img_superresolution(self):
assert image.shape == (256, 256, 3)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
+
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(
diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting.py b/tests/pipelines/deepfloyd_if/test_if_inpainting.py
index df0cecd8c307..a62d95725774 100644
--- a/tests/pipelines/deepfloyd_if/test_if_inpainting.py
+++ b/tests/pipelines/deepfloyd_if/test_if_inpainting.py
@@ -22,7 +22,21 @@
from diffusers import IFInpaintingPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
+ floats_tensor,
+ load_numpy,
+ require_accelerator,
+ require_hf_hub_version_greater,
+ require_torch_accelerator,
+ require_transformers_version_greater,
+ skip_mps,
+ slow,
+ torch_device,
+)
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
@@ -69,10 +83,8 @@ def get_dummy_inputs(self, device, seed=0):
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_save_load_float16(self):
# Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
super().test_save_load_float16(expected_max_diff=1e-1)
@@ -88,32 +100,41 @@ def test_inference_batch_single_identical(self):
expected_max_diff=1e-2,
)
+ @require_hf_hub_version_greater("0.26.5")
+ @require_transformers_version_greater("4.47.1")
+ def test_save_load_dduf(self):
+ super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
+ @unittest.skip("Test done elsewhere.")
+ def test_save_load_optional_components(self, expected_max_difference=0.0001):
+ pass
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class IFInpaintingPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_if_inpainting(self):
pipe = IFInpaintingPipeline.from_pretrained(
"DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
mask_image = floats_tensor((1, 3, 64, 64), rng=random.Random(1)).to(torch_device)
@@ -129,7 +150,7 @@ def test_if_inpainting(self):
)
image = output.images[0]
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(
diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
index 2e9f64773289..f98284bef646 100644
--- a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
@@ -22,7 +22,21 @@
from diffusers import IFInpaintingSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
+ floats_tensor,
+ load_numpy,
+ require_accelerator,
+ require_hf_hub_version_greater,
+ require_torch_accelerator,
+ require_transformers_version_greater,
+ skip_mps,
+ slow,
+ torch_device,
+)
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
@@ -71,10 +85,8 @@ def get_dummy_inputs(self, device, seed=0):
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_save_load_float16(self):
# Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
super().test_save_load_float16(expected_max_diff=1e-1)
@@ -90,33 +102,42 @@ def test_inference_batch_single_identical(self):
expected_max_diff=1e-2,
)
+ @require_hf_hub_version_greater("0.26.5")
+ @require_transformers_version_greater("4.47.1")
+ def test_save_load_dduf(self):
+ super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
+ @unittest.skip("Test done elsewhere.")
+ def test_save_load_optional_components(self, expected_max_difference=0.0001):
+ pass
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class IFInpaintingSuperResolutionPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_if_inpainting_superresolution(self):
pipe = IFInpaintingSuperResolutionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
# Super resolution test
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -138,7 +159,7 @@ def test_if_inpainting_superresolution(self):
assert image.shape == (256, 256, 3)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(
diff --git a/tests/pipelines/deepfloyd_if/test_if_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_superresolution.py
index 2e3c8c6e0e15..435b0cc6ec07 100644
--- a/tests/pipelines/deepfloyd_if/test_if_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_superresolution.py
@@ -22,7 +22,21 @@
from diffusers import IFSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
+ floats_tensor,
+ load_numpy,
+ require_accelerator,
+ require_hf_hub_version_greater,
+ require_torch_accelerator,
+ require_transformers_version_greater,
+ skip_mps,
+ slow,
+ torch_device,
+)
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -64,10 +78,8 @@ def get_dummy_inputs(self, device, seed=0):
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_save_load_float16(self):
# Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
super().test_save_load_float16(expected_max_diff=1e-1)
@@ -83,33 +95,42 @@ def test_inference_batch_single_identical(self):
expected_max_diff=1e-2,
)
+ @require_hf_hub_version_greater("0.26.5")
+ @require_transformers_version_greater("4.47.1")
+ def test_save_load_dduf(self):
+ super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
+ @unittest.skip("Test done elsewhere.")
+ def test_save_load_optional_components(self, expected_max_difference=0.0001):
+ pass
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class IFSuperResolutionPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_if_superresolution(self):
pipe = IFSuperResolutionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
# Super resolution test
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
@@ -125,7 +146,7 @@ def test_if_superresolution(self):
assert image.shape == (256, 256, 3)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(
diff --git a/tests/pipelines/easyanimate/__init__.py b/tests/pipelines/easyanimate/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/easyanimate/test_easyanimate.py b/tests/pipelines/easyanimate/test_easyanimate.py
new file mode 100644
index 000000000000..13d5c2f49b11
--- /dev/null
+++ b/tests/pipelines/easyanimate/test_easyanimate.py
@@ -0,0 +1,294 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import Qwen2Tokenizer, Qwen2VLForConditionalGeneration
+
+from diffusers import (
+ AutoencoderKLMagvit,
+ EasyAnimatePipeline,
+ EasyAnimateTransformer3DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ numpy_cosine_similarity_distance,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class EasyAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = EasyAnimatePipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = EasyAnimateTransformer3DModel(
+ num_attention_heads=2,
+ attention_head_dim=16,
+ in_channels=4,
+ out_channels=4,
+ time_embed_dim=2,
+ text_embed_dim=16, # Must match with tiny-random-t5
+ num_layers=1,
+ sample_width=16, # latent width: 2 -> final width: 16
+ sample_height=16, # latent height: 2 -> final height: 16
+ patch_size=2,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLMagvit(
+ in_channels=3,
+ out_channels=3,
+ down_block_types=(
+ "SpatialDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",
+ ),
+ up_block_types=(
+ "SpatialUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",
+ ),
+ block_out_channels=(8, 8, 8, 8),
+ latent_channels=4,
+ layers_per_block=1,
+ norm_num_groups=2,
+ spatial_group_norm=False,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(
+ "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
+ )
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 5,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (5, 3, 16, 16))
+ expected_video = torch.randn(5, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_dict_tuple_outputs_equivalent(self, expected_slice=None, expected_max_difference=0.001):
+ # Seems to need a higher tolerance
+ return super().test_dict_tuple_outputs_equivalent(expected_slice, expected_max_difference)
+
+ def test_encode_prompt_works_in_isolation(self):
+ # Seems to need a higher tolerance
+ return super().test_encode_prompt_works_in_isolation(atol=1e-3, rtol=1e-3)
+
+
+@slow
+@require_torch_gpu
+class EasyAnimatePipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_EasyAnimate(self):
+ generator = torch.Generator("cpu").manual_seed(0)
+
+ pipe = EasyAnimatePipeline.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", torch_dtype=torch.float16)
+ pipe.enable_model_cpu_offload()
+ prompt = self.prompt
+
+ videos = pipe(
+ prompt=prompt,
+ height=480,
+ width=720,
+ num_frames=5,
+ generator=generator,
+ num_inference_steps=2,
+ output_type="pt",
+ ).frames
+
+ video = videos[0]
+ expected_video = torch.randn(1, 5, 480, 720, 3).numpy()
+
+ max_diff = numpy_cosine_similarity_distance(video, expected_video)
+ assert max_diff < 1e-3, f"Max diff is too high. got {video}"
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index 4caff4030261..6a560367a5b8 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -2,39 +2,68 @@
import unittest
import numpy as np
+import pytest
import torch
+from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers import (
+ AutoencoderKL,
+ FasterCacheConfig,
+ FlowMatchEulerDiscreteScheduler,
+ FluxPipeline,
+ FluxTransformer2DModel,
+)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ nightly,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_big_accelerator,
slow,
torch_device,
)
from ..test_pipelines_common import (
+ FasterCacheTesterMixin,
+ FluxIPAdapterTesterMixin,
PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
-class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+class FluxPipelineFastTests(
+ unittest.TestCase,
+ PipelineTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
# there is no xformers processor for Flux
test_xformers_attention = False
-
- def get_dummy_components(self):
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ faster_cache_config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 901),
+ unconditional_batch_skip_range=2,
+ attention_weight_callback=lambda _: 0.5,
+ is_guidance_distilled=True,
+ )
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=4,
- num_layers=1,
- num_single_layers=1,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
@@ -89,6 +118,8 @@ def get_dummy_components(self):
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
@@ -125,30 +156,6 @@ def test_flux_different_prompts(self):
# For some reasons, they don't show large differences
assert max_diff > 1e-6
- def test_flux_prompt_embeds(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
- inputs = self.get_dummy_inputs(torch_device)
-
- output_with_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
- prompt,
- prompt_2=None,
- device=torch_device,
- max_sequence_length=inputs["max_sequence_length"],
- )
- output_with_embeds = pipe(
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- **inputs,
- ).images[0]
-
- max_diff = np.abs(output_with_prompt - output_with_embeds).max()
- assert max_diff < 1e-4
-
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -189,9 +196,35 @@ def test_fused_qkv_projections(self):
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
-@slow
-@require_torch_gpu
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_flux_true_cfg(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("generator")
+
+ no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ inputs["negative_prompt"] = "bad quality"
+ inputs["true_cfg_scale"] = 2.0
+ true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ assert not np.allclose(no_true_cfg_out, true_cfg_out)
+
+
+@nightly
+@require_big_accelerator
+@pytest.mark.big_gpu_with_torch_cuda
class FluxPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -199,12 +232,103 @@ class FluxPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_inputs(self, device, seed=0):
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ prompt_embeds = torch.load(
+ hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
+ ).to(torch_device)
+ pooled_prompt_embeds = torch.load(
+ hf_hub_download(
+ repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
+ )
+ ).to(torch_device)
+ return {
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ "num_inference_steps": 2,
+ "guidance_scale": 0.0,
+ "max_sequence_length": 256,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+ def test_flux_inference(self):
+ pipe = self.pipeline_class.from_pretrained(
+ self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
+ ).to(torch_device)
+
+ inputs = self.get_inputs(torch_device)
+
+ image = pipe(**inputs).images[0]
+ image_slice = image[0, :10, :10]
+ expected_slice = np.array(
+ [
+ 0.3242,
+ 0.3203,
+ 0.3164,
+ 0.3164,
+ 0.3125,
+ 0.3125,
+ 0.3281,
+ 0.3242,
+ 0.3203,
+ 0.3301,
+ 0.3262,
+ 0.3242,
+ 0.3281,
+ 0.3242,
+ 0.3203,
+ 0.3262,
+ 0.3262,
+ 0.3164,
+ 0.3262,
+ 0.3281,
+ 0.3184,
+ 0.3281,
+ 0.3281,
+ 0.3203,
+ 0.3281,
+ 0.3281,
+ 0.3164,
+ 0.3320,
+ 0.3320,
+ 0.3203,
+ ],
+ dtype=np.float32,
+ )
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
+
+ assert max_diff < 1e-4
+
+
+@slow
+@require_big_accelerator
+@pytest.mark.big_gpu_with_torch_cuda
+class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
+ pipeline_class = FluxPipeline
+ repo_id = "black-forest-labs/FLUX.1-dev"
+ image_encoder_pretrained_model_name_or_path = "openai/clip-vit-large-patch14"
+ weight_name = "ip_adapter.safetensors"
+ ip_adapter_repo_id = "XLabs-AI/flux-ip-adapter"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
@@ -212,40 +336,84 @@ def get_inputs(self, device, seed=0):
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
+ prompt_embeds = torch.load(
+ hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
+ )
+ pooled_prompt_embeds = torch.load(
+ hf_hub_download(
+ repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
+ )
+ )
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ ip_adapter_image = np.zeros((1024, 1024, 3), dtype=np.uint8)
return {
- "prompt": "A photo of a cat",
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "negative_pooled_prompt_embeds": negative_pooled_prompt_embeds,
+ "ip_adapter_image": ip_adapter_image,
"num_inference_steps": 2,
- "guidance_scale": 5.0,
+ "guidance_scale": 3.5,
+ "true_cfg_scale": 4.0,
+ "max_sequence_length": 256,
"output_type": "np",
"generator": generator,
}
- # TODO: Dhruv. Move large model tests to a dedicated runner)
- @unittest.skip("We cannot run inference on this model with the current CI hardware")
- def test_flux_inference(self):
- pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
+ def test_flux_ip_adapter_inference(self):
+ pipe = self.pipeline_class.from_pretrained(
+ self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
+ )
+ pipe.load_ip_adapter(
+ self.ip_adapter_repo_id,
+ weight_name=self.weight_name,
+ image_encoder_pretrained_model_name_or_path=self.image_encoder_pretrained_model_name_or_path,
+ )
+ pipe.set_ip_adapter_scale(1.0)
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
+
expected_slice = np.array(
[
- [0.36132812, 0.30004883, 0.25830078],
- [0.36669922, 0.31103516, 0.23754883],
- [0.34814453, 0.29248047, 0.23583984],
- [0.35791016, 0.30981445, 0.23999023],
- [0.36328125, 0.31274414, 0.2607422],
- [0.37304688, 0.32177734, 0.26171875],
- [0.3671875, 0.31933594, 0.25756836],
- [0.36035156, 0.31103516, 0.2578125],
- [0.3857422, 0.33789062, 0.27563477],
- [0.3701172, 0.31982422, 0.265625],
+ 0.1855,
+ 0.1680,
+ 0.1406,
+ 0.1953,
+ 0.1699,
+ 0.1465,
+ 0.2012,
+ 0.1738,
+ 0.1484,
+ 0.2051,
+ 0.1797,
+ 0.1523,
+ 0.2012,
+ 0.1719,
+ 0.1445,
+ 0.2070,
+ 0.1777,
+ 0.1465,
+ 0.2090,
+ 0.1836,
+ 0.1484,
+ 0.2129,
+ 0.1875,
+ 0.1523,
+ 0.2090,
+ 0.1816,
+ 0.1484,
+ 0.2110,
+ 0.1836,
+ 0.1543,
],
dtype=np.float32,
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
- assert max_diff < 1e-4
+ assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py
new file mode 100644
index 000000000000..d8293952adcb
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_control.py
@@ -0,0 +1,181 @@
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
+from diffusers.utils.testing_utils import torch_device
+
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fusion_matches_attn_procs_length,
+ check_qkv_fusion_processors_exist,
+)
+
+
+class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = FluxControlPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=8,
+ out_channels=4,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ control_image = Image.new("RGB", (16, 16), 0)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "control_image": control_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_flux_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reasons, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ pipe.transformer, pipe.transformer.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_img2img.py b/tests/pipelines/flux/test_pipeline_flux_control_img2img.py
new file mode 100644
index 000000000000..966543f63aeb
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_control_img2img.py
@@ -0,0 +1,144 @@
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ FluxControlImg2ImgPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class FluxControlImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = FluxControlImg2ImgPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+ test_xformers_attention = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=8,
+ out_channels=4,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ image = Image.new("RGB", (16, 16), 0)
+ control_image = Image.new("RGB", (16, 16), 0)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": image,
+ "control_image": control_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "strength": 0.8,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_flux_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reasons, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
new file mode 100644
index 000000000000..44ce2a4dedfc
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
@@ -0,0 +1,175 @@
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ FluxControlInpaintPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.utils.testing_utils import (
+ torch_device,
+)
+
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fusion_matches_attn_procs_length,
+ check_qkv_fusion_processors_exist,
+)
+
+
+class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = FluxControlInpaintPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=8,
+ out_channels=4,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ image = Image.new("RGB", (8, 8), 0)
+ control_image = Image.new("RGB", (8, 8), 0)
+ mask_image = Image.new("RGB", (8, 8), 255)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "control_image": control_image,
+ "generator": generator,
+ "image": image,
+ "mask_image": mask_image,
+ "strength": 0.8,
+ "num_inference_steps": 2,
+ "guidance_scale": 30.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ pipe.transformer, pipe.transformer.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/flux/test_pipeline_flux_fill.py b/tests/pipelines/flux/test_pipeline_flux_fill.py
new file mode 100644
index 000000000000..04d4c68db8f3
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_fill.py
@@ -0,0 +1,146 @@
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxFillPipeline, FluxTransformer2DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = FluxFillPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=20,
+ out_channels=8,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=2,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ mask_image = torch.ones((1, 1, 32, 32)).to(device)
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": image,
+ "mask_image": mask_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 48,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_flux_fill_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reasons, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(expected_max_diff=1e-3)
diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py
index a038b1725812..6d33ca721b6c 100644
--- a/tests/pipelines/flux/test_pipeline_flux_img2img.py
+++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py
@@ -12,13 +12,13 @@
torch_device,
)
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin
enable_full_determinism()
-class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
pipeline_class = FluxImg2ImgPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
@@ -85,6 +85,8 @@ def get_dummy_components(self):
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
@@ -124,26 +126,16 @@ def test_flux_different_prompts(self):
# For some reasons, they don't show large differences
assert max_diff > 1e-6
- def test_flux_prompt_embeds(self):
+ def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
- output_with_prompt = pipe(**inputs).images[0]
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
- prompt,
- prompt_2=None,
- device=torch_device,
- max_sequence_length=inputs["max_sequence_length"],
- )
- output_with_embeds = pipe(
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- **inputs,
- ).images[0]
-
- max_diff = np.abs(output_with_prompt - output_with_embeds).max()
- assert max_diff < 1e-4
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/flux/test_pipeline_flux_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_inpaint.py
index ac2eb1fa261b..161348455ca4 100644
--- a/tests/pipelines/flux/test_pipeline_flux_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_inpaint.py
@@ -12,13 +12,13 @@
torch_device,
)
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin
enable_full_determinism()
-class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
pipeline_class = FluxInpaintPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
@@ -85,6 +85,8 @@ def get_dummy_components(self):
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
@@ -126,26 +128,16 @@ def test_flux_inpaint_different_prompts(self):
# For some reasons, they don't show large differences
assert max_diff > 1e-6
- def test_flux_inpaint_prompt_embeds(self):
+ def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
- output_with_prompt = pipe(**inputs).images[0]
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
- prompt,
- prompt_2=None,
- device=torch_device,
- max_sequence_length=inputs["max_sequence_length"],
- )
- output_with_embeds = pipe(
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- **inputs,
- ).images[0]
-
- max_diff = np.abs(output_with_prompt - output_with_embeds).max()
- assert max_diff < 1e-4
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py
new file mode 100644
index 000000000000..2cd73a51a173
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_redux.py
@@ -0,0 +1,109 @@
+import gc
+import unittest
+
+import numpy as np
+import pytest
+import torch
+
+from diffusers import FluxPipeline, FluxPriorReduxPipeline
+from diffusers.utils import load_image
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ numpy_cosine_similarity_distance,
+ require_big_accelerator,
+ slow,
+ torch_device,
+)
+
+
+@slow
+@require_big_accelerator
+@pytest.mark.big_gpu_with_torch_cuda
+class FluxReduxSlowTests(unittest.TestCase):
+ pipeline_class = FluxPriorReduxPipeline
+ repo_id = "YiYiXu/yiyi-redux" # update to "black-forest-labs/FLUX.1-Redux-dev" once PR is merged
+ base_pipeline_class = FluxPipeline
+ base_repo_id = "black-forest-labs/FLUX.1-schnell"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_inputs(self, device, seed=0):
+ init_image = load_image(
+ "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png"
+ )
+ return {"image": init_image}
+
+ def get_base_pipeline_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ return {
+ "num_inference_steps": 2,
+ "guidance_scale": 2.0,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+ def test_flux_redux_inference(self):
+ pipe_redux = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
+ pipe_base = self.base_pipeline_class.from_pretrained(
+ self.base_repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
+ )
+ pipe_redux.to(torch_device)
+ pipe_base.enable_model_cpu_offload(device=torch_device)
+
+ inputs = self.get_inputs(torch_device)
+ base_pipeline_inputs = self.get_base_pipeline_inputs(torch_device)
+
+ redux_pipeline_output = pipe_redux(**inputs)
+ image = pipe_base(**base_pipeline_inputs, **redux_pipeline_output).images[0]
+
+ image_slice = image[0, :10, :10]
+ expected_slice = np.array(
+ [
+ 0.30078125,
+ 0.37890625,
+ 0.46875,
+ 0.28125,
+ 0.36914062,
+ 0.47851562,
+ 0.28515625,
+ 0.375,
+ 0.4765625,
+ 0.28125,
+ 0.375,
+ 0.48046875,
+ 0.27929688,
+ 0.37695312,
+ 0.47851562,
+ 0.27734375,
+ 0.38085938,
+ 0.4765625,
+ 0.2734375,
+ 0.38085938,
+ 0.47265625,
+ 0.27539062,
+ 0.37890625,
+ 0.47265625,
+ 0.27734375,
+ 0.37695312,
+ 0.47070312,
+ 0.27929688,
+ 0.37890625,
+ 0.47460938,
+ ],
+ dtype=np.float32,
+ )
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
+
+ assert max_diff < 1e-4
diff --git a/tests/pipelines/hunyuan_video/__init__.py b/tests/pipelines/hunyuan_video/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
new file mode 100644
index 000000000000..5802bde87a61
--- /dev/null
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
@@ -0,0 +1,366 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextConfig,
+ CLIPTextModel,
+ CLIPTokenizer,
+ LlamaConfig,
+ LlamaModel,
+ LlamaTokenizer,
+)
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanVideoImageToVideoPipeline,
+ HunyuanVideoTransformer3DModel,
+)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class HunyuanVideoImageToVideoPipelineFastTests(
+ PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase
+):
+ pipeline_class = HunyuanVideoImageToVideoPipeline
+ params = frozenset(
+ ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
+ )
+ batch_params = frozenset(["prompt", "image"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = HunyuanVideoTransformer3DModel(
+ in_channels=2 * 4 + 1,
+ out_channels=4,
+ num_attention_heads=2,
+ attention_head_dim=10,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ num_refiner_layers=1,
+ patch_size=1,
+ patch_size_t=1,
+ guidance_embeds=False,
+ text_embed_dim=16,
+ pooled_projection_dim=8,
+ rope_axes_dim=(2, 4, 4),
+ image_condition_type="latent_concat",
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLHunyuanVideo(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ down_block_types=(
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ),
+ up_block_types=(
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ),
+ block_out_channels=(8, 8, 8, 8),
+ layers_per_block=1,
+ act_fn="silu",
+ norm_num_groups=4,
+ scaling_factor=0.476986,
+ spatial_compression_ratio=8,
+ temporal_compression_ratio=4,
+ mid_block_add_attention=True,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ llama_text_encoder_config = LlamaConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=16,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=8,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = LlamaModel(llama_text_encoder_config)
+ tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
+
+ torch.manual_seed(0)
+ text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ torch.manual_seed(0)
+ image_processor = CLIPImageProcessor(
+ crop_size=336,
+ do_center_crop=True,
+ do_normalize=True,
+ do_resize=True,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711],
+ resample=3,
+ size=336,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "image_processor": image_processor,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "prompt_template": {
+ "template": "{}",
+ "crop_start": 0,
+ },
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 4.5,
+ "height": image_height,
+ "width": image_width,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ # NOTE: The expected video has 4 lesser frames because they are dropped in the pipeline
+ self.assertEqual(generated_video.shape, (5, 3, 16, 16))
+ expected_video = torch.randn(5, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ # Seems to require higher tolerance than the other tests
+ expected_diff_max = 0.6
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ @unittest.skip(
+ "Encode prompt currently does not work in isolation because of requiring image embeddings from image processor. The test does not handle this case, or we need to rewrite encode_prompt."
+ )
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
new file mode 100644
index 000000000000..bd3190de532d
--- /dev/null
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
@@ -0,0 +1,338 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, LlamaConfig, LlamaModel, LlamaTokenizer
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanSkyreelsImageToVideoPipeline,
+ HunyuanVideoTransformer3DModel,
+)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class HunyuanSkyreelsImageToVideoPipelineFastTests(
+ PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase
+):
+ pipeline_class = HunyuanSkyreelsImageToVideoPipeline
+ params = frozenset(
+ ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
+ )
+ batch_params = frozenset(["prompt", "image"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = HunyuanVideoTransformer3DModel(
+ in_channels=8,
+ out_channels=4,
+ num_attention_heads=2,
+ attention_head_dim=10,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ num_refiner_layers=1,
+ patch_size=1,
+ patch_size_t=1,
+ guidance_embeds=True,
+ text_embed_dim=16,
+ pooled_projection_dim=8,
+ rope_axes_dim=(2, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLHunyuanVideo(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ down_block_types=(
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ),
+ up_block_types=(
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ),
+ block_out_channels=(8, 8, 8, 8),
+ layers_per_block=1,
+ act_fn="silu",
+ norm_num_groups=4,
+ scaling_factor=0.476986,
+ spatial_compression_ratio=8,
+ temporal_compression_ratio=4,
+ mid_block_add_attention=True,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ llama_text_encoder_config = LlamaConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=16,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=8,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = LlamaModel(llama_text_encoder_config)
+ tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
+
+ torch.manual_seed(0)
+ text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "prompt_template": {
+ "template": "{}",
+ "crop_start": 0,
+ },
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 4.5,
+ "height": 16,
+ "width": 16,
+ # 4 * k + 1 is the recommendation
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ # Seems to require higher tolerance than the other tests
+ expected_diff_max = 0.6
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
new file mode 100644
index 000000000000..aa4f045966c3
--- /dev/null
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
@@ -0,0 +1,347 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, LlamaConfig, LlamaModel, LlamaTokenizer
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo,
+ FasterCacheConfig,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanVideoPipeline,
+ HunyuanVideoTransformer3DModel,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..test_pipelines_common import (
+ FasterCacheTesterMixin,
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ to_np,
+)
+
+
+enable_full_determinism()
+
+
+class HunyuanVideoPipelineFastTests(
+ PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+):
+ pipeline_class = HunyuanVideoPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ faster_cache_config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 901),
+ unconditional_batch_skip_range=2,
+ attention_weight_callback=lambda _: 0.5,
+ is_guidance_distilled=True,
+ )
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = HunyuanVideoTransformer3DModel(
+ in_channels=4,
+ out_channels=4,
+ num_attention_heads=2,
+ attention_head_dim=10,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ num_refiner_layers=1,
+ patch_size=1,
+ patch_size_t=1,
+ guidance_embeds=True,
+ text_embed_dim=16,
+ pooled_projection_dim=8,
+ rope_axes_dim=(2, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLHunyuanVideo(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ down_block_types=(
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ),
+ up_block_types=(
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ),
+ block_out_channels=(8, 8, 8, 8),
+ layers_per_block=1,
+ act_fn="silu",
+ norm_num_groups=4,
+ scaling_factor=0.476986,
+ spatial_compression_ratio=8,
+ temporal_compression_ratio=4,
+ mid_block_add_attention=True,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ llama_text_encoder_config = LlamaConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=16,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=8,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = LlamaModel(llama_text_encoder_config)
+ tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
+
+ torch.manual_seed(0)
+ text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "prompt_template": {
+ "template": "{}",
+ "crop_start": 0,
+ },
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 4.5,
+ "height": 16,
+ "width": 16,
+ # 4 * k + 1 is the recommendation
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ # Seems to require higher tolerance than the other tests
+ expected_diff_max = 0.6
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
diff --git a/tests/pipelines/hunyuandit/__init__.py b/tests/pipelines/hunyuandit/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py b/tests/pipelines/hunyuandit/test_hunyuan_dit.py
similarity index 96%
rename from tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
rename to tests/pipelines/hunyuandit/test_hunyuan_dit.py
index 653cb41e4bc4..5b1a82eda227 100644
--- a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py
+++ b/tests/pipelines/hunyuandit/test_hunyuan_dit.py
@@ -30,7 +30,7 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -55,6 +55,7 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
+ test_layerwise_casting = True
def get_dummy_components(self):
torch.manual_seed(0)
@@ -127,10 +128,12 @@ def test_inference(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
+ @unittest.skip("Not supported.")
def test_sequential_cpu_offload_forward_pass(self):
# TODO(YiYi) need to fix later
pass
+ @unittest.skip("Not supported.")
def test_sequential_offload_forward_pass_twice(self):
# TODO(YiYi) need to fix later
pass
@@ -140,6 +143,76 @@ def test_inference_batch_single_identical(self):
expected_max_diff=1e-3,
)
+ def test_feed_forward_chunking(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_no_chunking = image[0, -3:, -3:, -1]
+
+ pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_chunking = image[0, -3:, -3:, -1]
+
+ max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
+ self.assertLess(max_diff, 1e-4)
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["return_dict"] = False
+ image = pipe(**inputs)[0]
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ pipe.transformer.fuse_qkv_projections()
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ pipe.transformer, pipe.transformer.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["return_dict"] = False
+ image_fused = pipe(**inputs)[0]
+ image_slice_fused = image_fused[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ inputs["return_dict"] = False
+ image_disabled = pipe(**inputs)[0]
+ image_slice_disabled = image_disabled[0, -3:, -3:, -1]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
+
+ @unittest.skip(
+ "Test not supported as `encode_prompt` is called two times separately which deivates from about 99% of the pipelines we have."
+ )
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
def test_save_load_optional_components(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -233,73 +306,9 @@ def test_save_load_optional_components(self):
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1e-4)
- def test_feed_forward_chunking(self):
- device = "cpu"
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = pipe(**inputs).images
- image_slice_no_chunking = image[0, -3:, -3:, -1]
-
- pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
- inputs = self.get_dummy_inputs(device)
- image = pipe(**inputs).images
- image_slice_chunking = image[0, -3:, -3:, -1]
-
- max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
- self.assertLess(max_diff, 1e-4)
-
- def test_fused_qkv_projections(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["return_dict"] = False
- image = pipe(**inputs)[0]
- original_image_slice = image[0, -3:, -3:, -1]
-
- pipe.transformer.fuse_qkv_projections()
- # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
- # to the pipeline level.
- pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
-
- inputs = self.get_dummy_inputs(device)
- inputs["return_dict"] = False
- image_fused = pipe(**inputs)[0]
- image_slice_fused = image_fused[0, -3:, -3:, -1]
-
- pipe.transformer.unfuse_qkv_projections()
- inputs = self.get_dummy_inputs(device)
- inputs["return_dict"] = False
- image_disabled = pipe(**inputs)[0]
- image_slice_disabled = image_disabled[0, -3:, -3:, -1]
-
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
-
@slow
-@require_torch_gpu
+@require_torch_accelerator
class HunyuanDiTPipelineIntegrationTests(unittest.TestCase):
prompt = "一个宇航员在骑马"
@@ -319,7 +328,7 @@ def test_hunyuan_dit_1024(self):
pipe = HunyuanDiTPipeline.from_pretrained(
"XCLiu/HunyuanDiT-0523", revision="refs/pr/2", torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
image = pipe(
diff --git a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
index 592ebd35f4a9..868a40c9fb53 100644
--- a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
+++ b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
@@ -36,10 +36,11 @@
from diffusers.models.unets import I2VGenXLUNet
from diffusers.utils import is_xformers_available, load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
skip_mps,
slow,
torch_device,
@@ -59,6 +60,9 @@ class I2VGenXLPipelineFastTests(SDFunctionTesterMixin, PipelineTesterMixin, unit
# No `output_type`.
required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"])
+ supports_dduf = False
+ test_layerwise_casting = True
+
def get_dummy_components(self):
torch.manual_seed(0)
scheduler = DDIMScheduler(
@@ -224,25 +228,29 @@ def test_num_videos_per_prompt(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ @unittest.skip("Test not supported for now.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class I2VGenXLPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_i2vgen_xl(self):
pipe = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
index a8180a3bc27f..d5d4c20e471f 100644
--- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -34,11 +34,12 @@
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
is_flaky,
load_pt,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -54,13 +55,13 @@ def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_image_encoder(self, repo_id, subfolder):
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
@@ -165,7 +166,7 @@ def get_dummy_inputs(
@slow
-@require_torch_gpu
+@require_torch_accelerator
class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
def test_text_to_image(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
@@ -280,7 +281,7 @@ def test_text_to_image_model_cpu_offload(self):
inputs = self.get_dummy_inputs()
output_without_offload = pipeline(**inputs).images
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
inputs = self.get_dummy_inputs()
output_with_offload = pipeline(**inputs).images
max_diff = np.abs(output_with_offload - output_without_offload).max()
@@ -376,9 +377,10 @@ def test_text_to_image_face_id(self):
pipeline.set_ip_adapter_scale(0.7)
inputs = self.get_dummy_inputs()
- id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[
- 0
- ]
+ id_embeds = load_pt(
+ "https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt",
+ map_location=torch_device,
+ )[0]
id_embeds = id_embeds.reshape((2, 1, 1, 512))
inputs["ip_adapter_image_embeds"] = [id_embeds]
inputs["ip_adapter_image"] = None
@@ -391,7 +393,7 @@ def test_text_to_image_face_id(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
def test_text_to_image_sdxl(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
@@ -403,7 +405,7 @@ def test_text_to_image_sdxl(self):
feature_extractor=feature_extractor,
torch_dtype=self.dtype,
)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
inputs = self.get_dummy_inputs()
@@ -461,7 +463,7 @@ def test_image_to_image_sdxl(self):
feature_extractor=feature_extractor,
torch_dtype=self.dtype,
)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
inputs = self.get_dummy_inputs(for_image_to_image=True)
@@ -530,7 +532,7 @@ def test_inpainting_sdxl(self):
feature_extractor=feature_extractor,
torch_dtype=self.dtype,
)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
inputs = self.get_dummy_inputs(for_inpainting=True)
@@ -578,7 +580,7 @@ def test_ip_adapter_mask(self):
image_encoder=image_encoder,
torch_dtype=self.dtype,
)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.load_ip_adapter(
"h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus-face_sdxl_vit-h.safetensors"
)
@@ -606,7 +608,7 @@ def test_ip_adapter_multiple_masks(self):
image_encoder=image_encoder,
torch_dtype=self.dtype,
)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.load_ip_adapter(
"h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2
)
@@ -633,7 +635,7 @@ def test_instant_style_multiple_masks(self):
pipeline = StableDiffusionXLPipeline.from_pretrained(
"RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, image_encoder=image_encoder, variant="fp16"
)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.load_ip_adapter(
["ostris/ip-composition-adapter", "h94/IP-Adapter"],
@@ -674,7 +676,7 @@ def test_ip_adapter_multiple_masks_one_adapter(self):
image_encoder=image_encoder,
torch_dtype=self.dtype,
)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.load_ip_adapter(
"h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"]
)
diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py
index 8553ed96e9e1..30144e37a9d4 100644
--- a/tests/pipelines/kandinsky/test_kandinsky.py
+++ b/tests/pipelines/kandinsky/test_kandinsky.py
@@ -24,10 +24,11 @@
from diffusers import DDIMScheduler, KandinskyPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_numpy,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -204,6 +205,8 @@ class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
test_xformers_attention = False
+ supports_dduf = False
+
def get_dummy_components(self):
dummy = Dummies()
return dummy.get_dummy_components()
@@ -244,7 +247,7 @@ def test_kandinsky(self):
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- @require_torch_gpu
+ @require_torch_accelerator
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -253,12 +256,12 @@ def test_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -273,19 +276,19 @@ def test_offloads(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class KandinskyPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_kandinsky_text2img(self):
expected_image = load_numpy(
@@ -304,7 +307,7 @@ def test_kandinsky_text2img(self):
prompt = "red cat, 4k photo"
- generator = torch.Generator(device="cuda").manual_seed(0)
+ generator = torch.Generator(device=torch_device).manual_seed(0)
image_emb, zero_image_emb = pipe_prior(
prompt,
generator=generator,
@@ -312,7 +315,7 @@ def test_kandinsky_text2img(self):
negative_prompt="",
).to_tuple()
- generator = torch.Generator(device="cuda").manual_seed(0)
+ generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipeline(
prompt,
image_embeds=image_emb,
diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py
index 607a47e08e58..c5f27a9cc9a9 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py
@@ -18,7 +18,7 @@
import numpy as np
from diffusers import KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyInpaintCombinedPipeline
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
from .test_kandinsky import Dummies
@@ -52,6 +52,8 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase)
]
test_xformers_attention = True
+ supports_dduf = False
+
def get_dummy_components(self):
dummy = Dummies()
prior_dummy = PriorDummies()
@@ -103,7 +105,7 @@ def test_kandinsky(self):
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- @require_torch_gpu
+ @require_torch_accelerator
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -112,12 +114,12 @@ def test_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -160,6 +162,8 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te
]
test_xformers_attention = False
+ supports_dduf = False
+
def get_dummy_components(self):
dummy = Img2ImgDummies()
prior_dummy = PriorDummies()
@@ -209,7 +213,7 @@ def test_kandinsky(self):
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- @require_torch_gpu
+ @require_torch_accelerator
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -218,12 +222,12 @@ def test_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -269,6 +273,8 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te
]
test_xformers_attention = False
+ supports_dduf = False
+
def get_dummy_components(self):
dummy = InpaintDummies()
prior_dummy = PriorDummies()
@@ -308,8 +314,6 @@ def test_kandinsky(self):
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
- print(image_from_tuple_slice)
-
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.0320, 0.0860, 0.4013, 0.0518, 0.2484, 0.5847, 0.4411, 0.2321, 0.4593])
@@ -321,7 +325,7 @@ def test_kandinsky(self):
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- @require_torch_gpu
+ @require_torch_accelerator
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -330,12 +334,12 @@ def test_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
index ea289c5ccd71..26361ce18b82 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
@@ -32,12 +32,13 @@
)
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
nightly,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -226,6 +227,8 @@ class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
test_xformers_attention = False
+ supports_dduf = False
+
def get_dummy_components(self):
dummies = Dummies()
return dummies.get_dummy_components()
@@ -265,7 +268,7 @@ def test_kandinsky_img2img(self):
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- @require_torch_gpu
+ @require_torch_accelerator
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -297,19 +300,19 @@ def test_dict_tuple_outputs_equivalent(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class KandinskyImg2ImgPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_kandinsky_img2img(self):
expected_image = load_numpy(
@@ -363,19 +366,19 @@ def test_kandinsky_img2img(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class KandinskyImg2ImgPipelineNightlyTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_kandinsky_img2img_ddpm(self):
expected_image = load_numpy(
diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
index 740046678744..e30c601b6011 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
@@ -25,12 +25,13 @@
from diffusers import DDIMScheduler, KandinskyInpaintPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
nightly,
- require_torch_gpu,
+ require_torch_accelerator,
torch_device,
)
@@ -220,6 +221,8 @@ class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
test_xformers_attention = False
+ supports_dduf = False
+
def get_dummy_components(self):
dummies = Dummies()
return dummies.get_dummy_components()
@@ -263,7 +266,7 @@ def test_kandinsky_inpaint(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
- @require_torch_gpu
+ @require_torch_accelerator
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -272,12 +275,12 @@ def test_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -295,19 +298,19 @@ def test_float16_inference(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class KandinskyInpaintPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_kandinsky_inpaint(self):
expected_image = load_numpy(
diff --git a/tests/pipelines/kandinsky/test_kandinsky_prior.py b/tests/pipelines/kandinsky/test_kandinsky_prior.py
index 5f42447bd9d5..abb53bfb792f 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_prior.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_prior.py
@@ -184,6 +184,8 @@ class KandinskyPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
test_xformers_attention = False
+ supports_dduf = False
+
def get_dummy_components(self):
dummy = Dummies()
return dummy.get_dummy_components()
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky.py b/tests/pipelines/kandinsky2_2/test_kandinsky.py
index cbd9166efada..fea49d47b7bb 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky.py
@@ -22,12 +22,14 @@
from diffusers import DDIMScheduler, KandinskyV22Pipeline, KandinskyV22PriorPipeline, UNet2DConditionModel, VQModel
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_numpy,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
@@ -221,19 +223,19 @@ def test_float16_inference(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class KandinskyV22PipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_kandinsky_text2img(self):
expected_image = load_numpy(
@@ -244,12 +246,12 @@ def test_kandinsky_text2img(self):
pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
"kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
)
- pipe_prior.enable_model_cpu_offload()
+ pipe_prior.enable_model_cpu_offload(device=torch_device)
pipeline = KandinskyV22Pipeline.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
prompt = "red cat, 4k photo"
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
index dbba0831397b..90f8b2034109 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
@@ -22,7 +22,7 @@
KandinskyV22Img2ImgCombinedPipeline,
KandinskyV22InpaintCombinedPipeline,
)
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
from .test_kandinsky import Dummies
@@ -57,6 +57,8 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
test_xformers_attention = True
callback_cfg_params = ["image_embds"]
+ supports_dduf = False
+
def get_dummy_components(self):
dummy = Dummies()
prior_dummy = PriorDummies()
@@ -108,7 +110,7 @@ def test_kandinsky(self):
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- @require_torch_gpu
+ @require_torch_accelerator
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -117,12 +119,12 @@ def test_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -181,6 +183,8 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest
test_xformers_attention = False
callback_cfg_params = ["image_embds"]
+ supports_dduf = False
+
def get_dummy_components(self):
dummy = Img2ImgDummies()
prior_dummy = PriorDummies()
@@ -230,7 +234,7 @@ def test_kandinsky(self):
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- @require_torch_gpu
+ @require_torch_accelerator
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -239,12 +243,12 @@ def test_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -302,6 +306,8 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest
]
test_xformers_attention = False
+ supports_dduf = False
+
def get_dummy_components(self):
dummy = InpaintDummies()
prior_dummy = PriorDummies()
@@ -351,7 +357,7 @@ def test_kandinsky(self):
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- @require_torch_gpu
+ @require_torch_accelerator
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -360,12 +366,12 @@ def test_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
index 26d8b45cf900..4702f473a992 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
@@ -29,13 +29,15 @@
VQModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
@@ -238,19 +240,19 @@ def test_float16_inference(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class KandinskyV22Img2ImgPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_kandinsky_img2img(self):
expected_image = load_numpy(
@@ -266,12 +268,12 @@ def test_kandinsky_img2img(self):
pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
"kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
)
- pipe_prior.enable_model_cpu_offload()
+ pipe_prior.enable_model_cpu_offload(device=torch_device)
pipeline = KandinskyV22Img2ImgPipeline.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
index 25cf4bbed456..9a7f659e533c 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
@@ -29,13 +29,14 @@
VQModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
is_flaky,
load_image,
load_numpy,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -292,19 +293,19 @@ def callback_inputs_test(pipe, i, t, callback_kwargs):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class KandinskyV22InpaintPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_kandinsky_inpaint(self):
expected_image = load_numpy(
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
index be0bc238d4da..bdec6c132f80 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
@@ -186,6 +186,8 @@ class KandinskyV22PriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
callback_cfg_params = ["prompt_embeds", "text_encoder_hidden_states", "text_mask"]
test_xformers_attention = False
+ supports_dduf = False
+
def get_dummy_components(self):
dummies = Dummies()
return dummies.get_dummy_components()
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
index e898824e2d17..0ea32981d518 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
@@ -59,6 +59,8 @@ class KandinskyV22PriorEmb2EmbPipelineFastTests(PipelineTesterMixin, unittest.Te
]
test_xformers_attention = False
+ supports_dduf = False
+
@property
def text_embedder_hidden_size(self):
return 32
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py
index 941ef9093361..af1d45ff8975 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3.py
@@ -31,10 +31,12 @@
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from ..pipeline_params import (
@@ -167,25 +169,25 @@ def test_inference_batch_single_identical(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class Kandinsky3PipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_kandinskyV3(self):
pipe = AutoPipelineForText2Image.from_pretrained(
"kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
@@ -211,7 +213,7 @@ def test_kandinskyV3_img2img(self):
pipe = AutoPipelineForImage2Image.from_pretrained(
"kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
index 8c817df32e0c..e00948621a06 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
@@ -31,10 +31,11 @@
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -192,25 +193,25 @@ def test_inference_batch_single_identical(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class Kandinsky3Img2ImgPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_kandinskyV3_img2img(self):
pipe = AutoPipelineForImage2Image.from_pretrained(
"kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py
index de44af6d5908..218de2897e66 100644
--- a/tests/pipelines/kolors/test_kolors.py
+++ b/tests/pipelines/kolors/test_kolors.py
@@ -47,6 +47,9 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
+ supports_dduf = False
+ test_layerwise_casting = True
+
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -86,7 +89,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
sample_size=128,
)
torch.manual_seed(0)
- text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+ text_encoder = ChatGLMModel.from_pretrained(
+ "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
+ )
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
components = {
diff --git a/tests/pipelines/kolors/test_kolors_img2img.py b/tests/pipelines/kolors/test_kolors_img2img.py
index 2010dbd7055a..89da95753a14 100644
--- a/tests/pipelines/kolors/test_kolors_img2img.py
+++ b/tests/pipelines/kolors/test_kolors_img2img.py
@@ -51,6 +51,8 @@ class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase):
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
+ supports_dduf = False
+
# Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
@@ -91,7 +93,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
sample_size=128,
)
torch.manual_seed(0)
- text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+ text_encoder = ChatGLMModel.from_pretrained(
+ "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
+ )
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
components = {
@@ -150,3 +154,7 @@ def test_inference_batch_single_identical(self):
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=7e-2)
+
+ @unittest.skip("Test not supported because kolors img2img doesn't take pooled embeds as inputs unline kolors t2i.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
index b60a4553cded..570fa8fadf39 100644
--- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
+++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
@@ -13,8 +13,9 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -213,13 +214,20 @@ def callback_inputs_test(pipe, i, t, callback_kwargs):
output = pipe(**inputs)[0]
assert output.abs().sum() == 0
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class LatentConsistencyModelPipelineSlowTests(unittest.TestCase):
def setUp(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
index 386e60c54ac6..88e31a97aac5 100644
--- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
+++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
@@ -14,10 +14,11 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -220,13 +221,20 @@ def callback_inputs_test(pipe, i, t, callback_kwargs):
output = pipe(**inputs)[0]
assert output.abs().sum() == 0
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class LatentConsistencyModelImg2ImgPipelineSlowTests(unittest.TestCase):
def setUp(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
index 9b9a8ef65572..38ac6a46ccca 100644
--- a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
+++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
@@ -26,6 +26,7 @@
floats_tensor,
load_image,
nightly,
+ require_accelerator,
require_torch,
torch_device,
)
@@ -93,7 +94,7 @@ def test_inference_superresolution(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
+ @require_accelerator
def test_inference_superresolution_fp16(self):
unet = self.dummy_uncond_unet
scheduler = DDIMScheduler()
diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py
index 9667ebff249d..80d370647f57 100644
--- a/tests/pipelines/latte/test_latte.py
+++ b/tests/pipelines/latte/test_latte.py
@@ -25,26 +25,36 @@
from diffusers import (
AutoencoderKL,
DDIMScheduler,
+ FasterCacheConfig,
LattePipeline,
LatteTransformer3DModel,
+ PyramidAttentionBroadcastConfig,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import (
+ FasterCacheTesterMixin,
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ to_np,
+)
enable_full_determinism()
-class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class LattePipelineFastTests(
+ PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+):
pipeline_class = LattePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -52,12 +62,35 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ pab_config = PyramidAttentionBroadcastConfig(
+ spatial_attention_block_skip_range=2,
+ temporal_attention_block_skip_range=2,
+ cross_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(100, 700),
+ temporal_attention_timestep_skip_range=(100, 800),
+ cross_attention_timestep_skip_range=(100, 800),
+ spatial_attention_block_identifiers=["transformer_blocks"],
+ temporal_attention_block_identifiers=["temporal_transformer_blocks"],
+ cross_attention_block_identifiers=["transformer_blocks"],
+ )
- def get_dummy_components(self):
+ faster_cache_config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ temporal_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 901),
+ temporal_attention_timestep_skip_range=(-1, 901),
+ unconditional_batch_skip_range=2,
+ attention_weight_callback=lambda _: 0.5,
+ )
+
+ def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LatteTransformer3DModel(
sample_size=8,
- num_layers=1,
+ num_layers=num_layers,
patch_size=2,
attention_head_dim=8,
num_attention_heads=3,
@@ -187,9 +220,21 @@ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+ @unittest.skip("Not supported.")
def test_attention_slicing_forward_pass(self):
pass
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
+
+ @unittest.skip("Test not supported because `encode_prompt()` has multiple returns.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
def test_save_load_optional_components(self):
if not hasattr(self.pipeline_class, "_optional_components"):
return
@@ -257,34 +302,27 @@ def test_save_load_optional_components(self):
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1.0)
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
-
@slow
-@require_torch_gpu
+@require_torch_accelerator
class LattePipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_latte(self):
generator = torch.Generator("cpu").manual_seed(0)
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
videos = pipe(
diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
index effea2619749..342561d4f5e9 100644
--- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
+++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
@@ -29,10 +29,11 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
skip_mps,
slow,
torch_device,
@@ -146,7 +147,7 @@ def test_ledits_pp_inversion(self):
)
latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device)
- print(latent_slice.flatten())
+
expected_slice = np.array([-0.9084, -0.0367, 0.2940, 0.0839, 0.6890, 0.2651, -0.7104, 2.1090, -0.7822])
assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3
@@ -167,12 +168,12 @@ def test_ledits_pp_inversion_batch(self):
)
latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device)
- print(latent_slice.flatten())
+
expected_slice = np.array([0.2528, 0.1458, -0.2166, 0.4565, -0.5657, -1.0286, -0.9961, 0.5933, 1.1173])
assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3
latent_slice = sd_pipe.init_latents[1, -1, -3:, -3:].to(device)
- print(latent_slice.flatten())
+
expected_slice = np.array([-0.0796, 2.0583, 0.5501, 0.5358, 0.0282, -0.2803, -1.0470, 0.7023, -0.0072])
assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3
@@ -202,17 +203,17 @@ def test_ledits_pp_warmup_steps(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class LEditsPPPipelineStableDiffusionSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
@classmethod
def setUpClass(cls):
diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
index fcfd0aa51b9f..75795a33422b 100644
--- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
+++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
@@ -41,7 +41,7 @@
enable_full_determinism,
floats_tensor,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
skip_mps,
slow,
torch_device,
@@ -216,14 +216,14 @@ def test_ledits_pp_inversion_batch(self):
)
latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device)
- print(latent_slice.flatten())
+
expected_slice = np.array([0.2528, 0.1458, -0.2166, 0.4565, -0.5656, -1.0286, -0.9961, 0.5933, 1.1172])
assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3
latent_slice = sd_pipe.init_latents[1, -1, -3:, -3:].to(device)
- print(latent_slice.flatten())
+
expected_slice = np.array([-0.0796, 2.0583, 0.5500, 0.5358, 0.0282, -0.2803, -1.0470, 0.7024, -0.0072])
- print(latent_slice.flatten())
+
assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3
def test_ledits_pp_warmup_steps(self):
@@ -253,7 +253,7 @@ def test_ledits_pp_warmup_steps(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class LEditsPPPipelineStableDiffusionXLSlowTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
diff --git a/tests/pipelines/ltx/__init__.py b/tests/pipelines/ltx/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py
new file mode 100644
index 000000000000..4f72729fc9ce
--- /dev/null
+++ b/tests/pipelines/ltx/test_ltx.py
@@ -0,0 +1,267 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = LTXPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = LTXVideoTransformer3DModel(
+ in_channels=8,
+ out_channels=8,
+ patch_size=1,
+ patch_size_t=1,
+ num_attention_heads=4,
+ attention_head_dim=8,
+ cross_attention_dim=32,
+ num_layers=1,
+ caption_channels=32,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLLTXVideo(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=8,
+ block_out_channels=(8, 8, 8, 8),
+ decoder_block_out_channels=(8, 8, 8, 8),
+ layers_per_block=(1, 1, 1, 1, 1),
+ decoder_layers_per_block=(1, 1, 1, 1, 1),
+ spatio_temporal_scaling=(True, True, False, False),
+ decoder_spatio_temporal_scaling=(True, True, False, False),
+ decoder_inject_noise=(False, False, False, False, False),
+ upsample_residual=(False, False, False, False),
+ upsample_factor=(1, 1, 1, 1),
+ timestep_conditioning=False,
+ patch_size=1,
+ patch_size_t=1,
+ encoder_causal=True,
+ decoder_causal=False,
+ )
+ vae.use_framewise_encoding = False
+ vae.use_framewise_decoding = False
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "height": 32,
+ "width": 32,
+ # 8 * k + 1 is the recommendation
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+ expected_video = torch.randn(9, 3, 32, 32)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
diff --git a/tests/pipelines/ltx/test_ltx_condition.py b/tests/pipelines/ltx/test_ltx_condition.py
new file mode 100644
index 000000000000..dbb9a740b433
--- /dev/null
+++ b/tests/pipelines/ltx/test_ltx_condition.py
@@ -0,0 +1,284 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLLTXVideo,
+ FlowMatchEulerDiscreteScheduler,
+ LTXConditionPipeline,
+ LTXVideoTransformer3DModel,
+)
+from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = LTXConditionPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = LTXVideoTransformer3DModel(
+ in_channels=8,
+ out_channels=8,
+ patch_size=1,
+ patch_size_t=1,
+ num_attention_heads=4,
+ attention_head_dim=8,
+ cross_attention_dim=32,
+ num_layers=1,
+ caption_channels=32,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLLTXVideo(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=8,
+ block_out_channels=(8, 8, 8, 8),
+ decoder_block_out_channels=(8, 8, 8, 8),
+ layers_per_block=(1, 1, 1, 1, 1),
+ decoder_layers_per_block=(1, 1, 1, 1, 1),
+ spatio_temporal_scaling=(True, True, False, False),
+ decoder_spatio_temporal_scaling=(True, True, False, False),
+ decoder_inject_noise=(False, False, False, False, False),
+ upsample_residual=(False, False, False, False),
+ upsample_factor=(1, 1, 1, 1),
+ timestep_conditioning=False,
+ patch_size=1,
+ patch_size_t=1,
+ encoder_causal=True,
+ decoder_causal=False,
+ )
+ vae.use_framewise_encoding = False
+ vae.use_framewise_decoding = False
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0, use_conditions=False):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
+ if use_conditions:
+ conditions = LTXVideoCondition(
+ image=image,
+ )
+ else:
+ conditions = None
+
+ inputs = {
+ "conditions": conditions,
+ "image": None if use_conditions else image,
+ "prompt": "dance monkey",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "height": 32,
+ "width": 32,
+ # 8 * k + 1 is the recommendation
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs2 = self.get_dummy_inputs(device, use_conditions=True)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ video2 = pipe(**inputs2).frames
+ generated_video2 = video2[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+
+ max_diff = np.abs(generated_video - generated_video2).max()
+ self.assertLessEqual(max_diff, 1e-3)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
diff --git a/tests/pipelines/ltx/test_ltx_image2video.py b/tests/pipelines/ltx/test_ltx_image2video.py
new file mode 100644
index 000000000000..1c3e018a8a4b
--- /dev/null
+++ b/tests/pipelines/ltx/test_ltx_image2video.py
@@ -0,0 +1,273 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLLTXVideo,
+ FlowMatchEulerDiscreteScheduler,
+ LTXImageToVideoPipeline,
+ LTXVideoTransformer3DModel,
+)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = LTXImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = LTXVideoTransformer3DModel(
+ in_channels=8,
+ out_channels=8,
+ patch_size=1,
+ patch_size_t=1,
+ num_attention_heads=4,
+ attention_head_dim=8,
+ cross_attention_dim=32,
+ num_layers=1,
+ caption_channels=32,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLLTXVideo(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=8,
+ block_out_channels=(8, 8, 8, 8),
+ decoder_block_out_channels=(8, 8, 8, 8),
+ layers_per_block=(1, 1, 1, 1, 1),
+ decoder_layers_per_block=(1, 1, 1, 1, 1),
+ spatio_temporal_scaling=(True, True, False, False),
+ decoder_spatio_temporal_scaling=(True, True, False, False),
+ decoder_inject_noise=(False, False, False, False, False),
+ upsample_residual=(False, False, False, False),
+ upsample_factor=(1, 1, 1, 1),
+ timestep_conditioning=False,
+ patch_size=1,
+ patch_size_t=1,
+ encoder_causal=True,
+ decoder_causal=False,
+ )
+ vae.use_framewise_encoding = False
+ vae.use_framewise_decoding = False
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
+
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "height": 32,
+ "width": 32,
+ # 8 * k + 1 is the recommendation
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+ expected_video = torch.randn(9, 3, 32, 32)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py
index 5fd0dbf06050..0c1fe8eb2fcd 100644
--- a/tests/pipelines/lumina/test_lumina_nextdit.py
+++ b/tests/pipelines/lumina/test_lumina_nextdit.py
@@ -5,10 +5,17 @@
import torch
from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ LuminaNextDiT2DModel,
+ LuminaPipeline,
+ LuminaText2ImgPipeline,
+)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -16,8 +23,8 @@
from ..test_pipelines_common import PipelineTesterMixin
-class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
- pipeline_class = LuminaText2ImgPipeline
+class LuminaPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = LuminaPipeline
params = frozenset(
[
"prompt",
@@ -31,6 +38,10 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM
)
batch_params = frozenset(["prompt", "negative_prompt"])
+ supports_dduf = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
def get_dummy_components(self):
torch.manual_seed(0)
transformer = LuminaNextDiT2DModel(
@@ -90,55 +101,32 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs
- def test_lumina_prompt_embeds(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
- inputs = self.get_dummy_inputs(torch_device)
-
- output_with_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- do_classifier_free_guidance = inputs["guidance_scale"] > 1
- (
- prompt_embeds,
- prompt_attention_mask,
- negative_prompt_embeds,
- negative_prompt_attention_mask,
- ) = pipe.encode_prompt(
- prompt,
- do_classifier_free_guidance=do_classifier_free_guidance,
- device=torch_device,
- )
- output_with_embeds = pipe(
- prompt_embeds=prompt_embeds,
- prompt_attention_mask=prompt_attention_mask,
- **inputs,
- ).images[0]
-
- max_diff = np.abs(output_with_prompt - output_with_embeds).max()
- assert max_diff < 1e-4
-
@unittest.skip("xformers attention processor does not exist for Lumina")
def test_xformers_attention_forwardGenerator_pass(self):
pass
+ def test_deprecation_raises_warning(self):
+ with self.assertWarns(FutureWarning) as warning:
+ _ = LuminaText2ImgPipeline(**self.get_dummy_components()).to(torch_device)
+ warning_message = str(warning.warnings[0].message)
+ assert "renamed to `LuminaPipeline`" in warning_message
+
@slow
-@require_torch_gpu
-class LuminaText2ImgPipelineSlowTests(unittest.TestCase):
- pipeline_class = LuminaText2ImgPipeline
+@require_torch_accelerator
+class LuminaPipelineSlowTests(unittest.TestCase):
+ pipeline_class = LuminaPipeline
repo_id = "Alpha-VLLM/Lumina-Next-SFT-diffusers"
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
@@ -156,7 +144,7 @@ def get_inputs(self, device, seed=0):
def test_lumina_inference(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device)
diff --git a/tests/pipelines/lumina2/__init__.py b/tests/pipelines/lumina2/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2.py b/tests/pipelines/lumina2/test_pipeline_lumina2.py
new file mode 100644
index 000000000000..33fc870bcd34
--- /dev/null
+++ b/tests/pipelines/lumina2/test_pipeline_lumina2.py
@@ -0,0 +1,125 @@
+import unittest
+
+import torch
+from transformers import AutoTokenizer, Gemma2Config, Gemma2Model
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ Lumina2Pipeline,
+ Lumina2Text2ImgPipeline,
+ Lumina2Transformer2DModel,
+)
+from diffusers.utils.testing_utils import torch_device
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+class Lumina2PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = Lumina2Pipeline
+ params = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "guidance_scale",
+ "negative_prompt",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+ )
+ batch_params = frozenset(["prompt", "negative_prompt"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = Lumina2Transformer2DModel(
+ sample_size=4,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=8,
+ num_layers=2,
+ num_attention_heads=1,
+ num_kv_heads=1,
+ multiple_of=16,
+ ffn_dim_multiplier=None,
+ norm_eps=1e-5,
+ scaling_factor=1.0,
+ axes_dim_rope=[4, 2, 2],
+ cap_feat_dim=8,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=4,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ torch.manual_seed(0)
+ config = Gemma2Config(
+ head_dim=4,
+ hidden_size=8,
+ intermediate_size=8,
+ num_attention_heads=2,
+ num_hidden_layers=2,
+ num_key_value_heads=2,
+ sliding_window=2,
+ )
+ text_encoder = Gemma2Model(config)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae.eval(),
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 32,
+ "width": 32,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_deprecation_raises_warning(self):
+ with self.assertWarns(FutureWarning) as warning:
+ _ = Lumina2Text2ImgPipeline(**self.get_dummy_components()).to(torch_device)
+ warning_message = str(warning.warnings[0].message)
+ assert "renamed to `Lumina2Pipeline`" in warning_message
diff --git a/tests/pipelines/marigold/test_marigold_depth.py b/tests/pipelines/marigold/test_marigold_depth.py
index fcb9adca7a7b..13f9a421861b 100644
--- a/tests/pipelines/marigold/test_marigold_depth.py
+++ b/tests/pipelines/marigold/test_marigold_depth.py
@@ -1,5 +1,5 @@
-# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
+# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# limitations under the License.
# --------------------------------------------------------------------------
# More information and citation instructions are available on the
-# Marigold project website: https://marigoldmonodepth.github.io
+# Marigold project website: https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
import gc
import random
@@ -32,12 +32,14 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
is_flaky,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
@@ -288,17 +290,17 @@ def test_marigold_depth_dummy_no_processing_resolution(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def _test_marigold_depth(
self,
@@ -317,8 +319,7 @@ def _test_marigold_depth(
from_pretrained_kwargs["torch_dtype"] = torch.float16
pipe = MarigoldDepthPipeline.from_pretrained(model_id, **from_pretrained_kwargs)
- if device == "cuda":
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(generator_seed)
@@ -358,7 +359,7 @@ def test_marigold_depth_einstein_f32_cpu_G0_S1_P32_E1_B1_M1(self):
def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=False,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.1244, 0.1265, 0.1292, 0.1240, 0.1252, 0.1266, 0.1246, 0.1226, 0.1180]),
num_inference_steps=1,
@@ -371,7 +372,7 @@ def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.1241, 0.1262, 0.1290, 0.1238, 0.1250, 0.1265, 0.1244, 0.1225, 0.1179]),
num_inference_steps=1,
@@ -384,7 +385,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=2024,
expected_slice=np.array([0.1710, 0.1725, 0.1738, 0.1700, 0.1700, 0.1696, 0.1698, 0.1663, 0.1592]),
num_inference_steps=1,
@@ -397,7 +398,7 @@ def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
def test_marigold_depth_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]),
num_inference_steps=2,
@@ -410,7 +411,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.2683, 0.2693, 0.2698, 0.2666, 0.2632, 0.2615, 0.2656, 0.2603, 0.2573]),
num_inference_steps=1,
@@ -423,7 +424,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.1200, 0.1215, 0.1237, 0.1193, 0.1197, 0.1202, 0.1196, 0.1166, 0.1109]),
num_inference_steps=1,
@@ -437,7 +438,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
self._test_marigold_depth(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.1121, 0.1135, 0.1155, 0.1111, 0.1115, 0.1118, 0.1111, 0.1079, 0.1019]),
num_inference_steps=1,
@@ -451,7 +452,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self):
self._test_marigold_depth(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.2671, 0.2690, 0.2720, 0.2659, 0.2676, 0.2739, 0.2664, 0.2686, 0.2573]),
num_inference_steps=1,
diff --git a/tests/pipelines/marigold/test_marigold_intrinsics.py b/tests/pipelines/marigold/test_marigold_intrinsics.py
new file mode 100644
index 000000000000..b24e686a4dfe
--- /dev/null
+++ b/tests/pipelines/marigold/test_marigold_intrinsics.py
@@ -0,0 +1,571 @@
+# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
+# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+# More information and citation instructions are available on the
+# Marigold project website: https://marigoldcomputervision.github.io
+# --------------------------------------------------------------------------
+import gc
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ AutoencoderTiny,
+ DDIMScheduler,
+ MarigoldIntrinsicsPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ load_image,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class MarigoldIntrinsicsPipelineTesterMixin(PipelineTesterMixin):
+ def _test_inference_batch_single_identical(
+ self,
+ batch_size=2,
+ expected_max_diff=1e-4,
+ additional_params_copy_to_batched_inputs=["num_inference_steps"],
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for components in pipe.components.values():
+ if hasattr(components, "set_default_attn_processor"):
+ components.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(torch_device)
+ # Reset generator in case it is has been used in self.get_dummy_inputs
+ inputs["generator"] = self.get_generator(0)
+
+ logger = diffusers.logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # batchify inputs
+ batched_inputs = {}
+ batched_inputs.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ if name == "prompt":
+ len_prompt = len(value)
+ batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+ batched_inputs[name][-1] = 100 * "very long"
+
+ else:
+ batched_inputs[name] = batch_size * [value]
+
+ if "generator" in inputs:
+ batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_inputs["batch_size"] = batch_size
+
+ for arg in additional_params_copy_to_batched_inputs:
+ batched_inputs[arg] = inputs[arg]
+
+ output = pipe(**inputs)
+ output_batch = pipe(**batched_inputs)
+
+ assert output_batch[0].shape[0] == batch_size * output[0].shape[0] # only changed here
+
+ max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
+ assert max_diff < expected_max_diff
+
+ def _test_inference_batch_consistent(
+ self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"], batch_generator=True
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["generator"] = self.get_generator(0)
+
+ logger = diffusers.logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # prepare batched inputs
+ batched_inputs = []
+ for batch_size in batch_sizes:
+ batched_input = {}
+ batched_input.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ if name == "prompt":
+ len_prompt = len(value)
+ # make unequal batch sizes
+ batched_input[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+
+ # make last batch super long
+ batched_input[name][-1] = 100 * "very long"
+
+ else:
+ batched_input[name] = batch_size * [value]
+
+ if batch_generator and "generator" in inputs:
+ batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_input["batch_size"] = batch_size
+
+ batched_inputs.append(batched_input)
+
+ logger.setLevel(level=diffusers.logging.WARNING)
+ for batch_size, batched_input in zip(batch_sizes, batched_inputs):
+ output = pipe(**batched_input)
+ assert len(output[0]) == batch_size * pipe.n_targets # only changed here
+
+
+class MarigoldIntrinsicsPipelineFastTests(MarigoldIntrinsicsPipelineTesterMixin, unittest.TestCase):
+ pipeline_class = MarigoldIntrinsicsPipeline
+ params = frozenset(["image"])
+ batch_params = frozenset(["image"])
+ image_params = frozenset(["image"])
+ image_latents_params = frozenset(["latents"])
+ callback_cfg_params = frozenset([])
+ test_xformers_attention = False
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "output_type",
+ ]
+ )
+
+ def get_dummy_components(self, time_cond_proj_dim=None):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ time_cond_proj_dim=time_cond_proj_dim,
+ sample_size=32,
+ in_channels=12,
+ out_channels=8,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ )
+ torch.manual_seed(0)
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ prediction_type="v_prediction",
+ set_alpha_to_one=False,
+ steps_offset=1,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ thresholding=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "prediction_type": "intrinsics",
+ }
+ return components
+
+ def get_dummy_tiny_autoencoder(self):
+ return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ image = image / 2 + 0.5
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "image": image,
+ "num_inference_steps": 1,
+ "processing_resolution": 0,
+ "generator": generator,
+ "output_type": "np",
+ }
+ return inputs
+
+ def _test_marigold_intrinsics(
+ self,
+ generator_seed: int = 0,
+ expected_slice: np.ndarray = None,
+ atol: float = 1e-4,
+ **pipe_kwargs,
+ ):
+ device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe_inputs = self.get_dummy_inputs(device, seed=generator_seed)
+ pipe_inputs.update(**pipe_kwargs)
+
+ prediction = pipe(**pipe_inputs).prediction
+
+ prediction_slice = prediction[0, -3:, -3:, -1].flatten()
+
+ if pipe_inputs.get("match_input_resolution", True):
+ self.assertEqual(prediction.shape, (2, 32, 32, 3), "Unexpected output resolution")
+ else:
+ self.assertTrue(prediction.shape[0] == 2 and prediction.shape[3] == 3, "Unexpected output dimensions")
+ self.assertEqual(
+ max(prediction.shape[1:3]),
+ pipe_inputs.get("processing_resolution", 768),
+ "Unexpected output resolution",
+ )
+
+ np.set_printoptions(precision=5, suppress=True)
+ msg = f"{prediction_slice}"
+ self.assertTrue(np.allclose(prediction_slice, expected_slice, atol=atol), msg)
+ # self.assertTrue(np.allclose(prediction_slice, expected_slice, atol=atol))
+
+ def test_marigold_depth_dummy_defaults(self):
+ self._test_marigold_intrinsics(
+ expected_slice=np.array([0.6423, 0.40664, 0.41185, 0.65832, 0.63935, 0.43971, 0.51786, 0.55216, 0.47683]),
+ )
+
+ def test_marigold_depth_dummy_G0_S1_P32_E1_B1_M1(self):
+ self._test_marigold_intrinsics(
+ generator_seed=0,
+ expected_slice=np.array([0.6423, 0.40664, 0.41185, 0.65832, 0.63935, 0.43971, 0.51786, 0.55216, 0.47683]),
+ num_inference_steps=1,
+ processing_resolution=32,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_depth_dummy_G0_S1_P16_E1_B1_M1(self):
+ self._test_marigold_intrinsics(
+ generator_seed=0,
+ expected_slice=np.array([0.53132, 0.44487, 0.40164, 0.5326, 0.49073, 0.46979, 0.53324, 0.51366, 0.50387]),
+ num_inference_steps=1,
+ processing_resolution=16,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_depth_dummy_G2024_S1_P32_E1_B1_M1(self):
+ self._test_marigold_intrinsics(
+ generator_seed=2024,
+ expected_slice=np.array([0.40257, 0.39468, 0.51373, 0.4161, 0.40162, 0.58535, 0.43581, 0.47834, 0.48951]),
+ num_inference_steps=1,
+ processing_resolution=32,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_depth_dummy_G0_S2_P32_E1_B1_M1(self):
+ self._test_marigold_intrinsics(
+ generator_seed=0,
+ expected_slice=np.array([0.49636, 0.4518, 0.42722, 0.59044, 0.6362, 0.39011, 0.53522, 0.55153, 0.48699]),
+ num_inference_steps=2,
+ processing_resolution=32,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_depth_dummy_G0_S1_P64_E1_B1_M1(self):
+ self._test_marigold_intrinsics(
+ generator_seed=0,
+ expected_slice=np.array([0.55547, 0.43511, 0.4887, 0.56399, 0.63867, 0.56337, 0.47889, 0.52925, 0.49235]),
+ num_inference_steps=1,
+ processing_resolution=64,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_depth_dummy_G0_S1_P32_E3_B1_M1(self):
+ self._test_marigold_intrinsics(
+ generator_seed=0,
+ expected_slice=np.array([0.57249, 0.49824, 0.54438, 0.57733, 0.52404, 0.5255, 0.56493, 0.56336, 0.48579]),
+ num_inference_steps=1,
+ processing_resolution=32,
+ ensemble_size=3,
+ ensembling_kwargs={"reduction": "mean"},
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_depth_dummy_G0_S1_P32_E4_B2_M1(self):
+ self._test_marigold_intrinsics(
+ generator_seed=0,
+ expected_slice=np.array([0.6294, 0.5575, 0.53414, 0.61077, 0.57156, 0.53974, 0.52956, 0.55467, 0.48751]),
+ num_inference_steps=1,
+ processing_resolution=32,
+ ensemble_size=4,
+ ensembling_kwargs={"reduction": "mean"},
+ batch_size=2,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_depth_dummy_G0_S1_P16_E1_B1_M0(self):
+ self._test_marigold_intrinsics(
+ generator_seed=0,
+ expected_slice=np.array([0.63511, 0.68137, 0.48783, 0.46689, 0.58505, 0.36757, 0.58465, 0.54302, 0.50387]),
+ num_inference_steps=1,
+ processing_resolution=16,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=False,
+ )
+
+ def test_marigold_depth_dummy_no_num_inference_steps(self):
+ with self.assertRaises(ValueError) as e:
+ self._test_marigold_intrinsics(
+ num_inference_steps=None,
+ expected_slice=np.array([0.0]),
+ )
+ self.assertIn("num_inference_steps", str(e))
+
+ def test_marigold_depth_dummy_no_processing_resolution(self):
+ with self.assertRaises(ValueError) as e:
+ self._test_marigold_intrinsics(
+ processing_resolution=None,
+ expected_slice=np.array([0.0]),
+ )
+ self.assertIn("processing_resolution", str(e))
+
+
+@slow
+@require_torch_gpu
+class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def _test_marigold_intrinsics(
+ self,
+ is_fp16: bool = True,
+ device: str = "cuda",
+ generator_seed: int = 0,
+ expected_slice: np.ndarray = None,
+ model_id: str = "prs-eth/marigold-iid-appearance-v1-1",
+ image_url: str = "https://marigoldmonodepth.github.io/images/einstein.jpg",
+ atol: float = 1e-4,
+ **pipe_kwargs,
+ ):
+ from_pretrained_kwargs = {}
+ if is_fp16:
+ from_pretrained_kwargs["variant"] = "fp16"
+ from_pretrained_kwargs["torch_dtype"] = torch.float16
+
+ pipe = MarigoldIntrinsicsPipeline.from_pretrained(model_id, **from_pretrained_kwargs)
+ if device == "cuda":
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device=device).manual_seed(generator_seed)
+
+ image = load_image(image_url)
+ width, height = image.size
+
+ prediction = pipe(image, generator=generator, **pipe_kwargs).prediction
+
+ prediction_slice = prediction[0, -3:, -3:, -1].flatten()
+
+ if pipe_kwargs.get("match_input_resolution", True):
+ self.assertEqual(prediction.shape, (2, height, width, 3), "Unexpected output resolution")
+ else:
+ self.assertTrue(prediction.shape[0] == 2 and prediction.shape[3] == 3, "Unexpected output dimensions")
+ self.assertEqual(
+ max(prediction.shape[1:3]),
+ pipe_kwargs.get("processing_resolution", 768),
+ "Unexpected output resolution",
+ )
+
+ msg = f"{prediction_slice}"
+ self.assertTrue(np.allclose(prediction_slice, expected_slice, atol=atol), msg)
+ # self.assertTrue(np.allclose(prediction_slice, expected_slice, atol=atol))
+
+ def test_marigold_intrinsics_einstein_f32_cpu_G0_S1_P32_E1_B1_M1(self):
+ self._test_marigold_intrinsics(
+ is_fp16=False,
+ device="cpu",
+ generator_seed=0,
+ expected_slice=np.array([0.9162, 0.9162, 0.9162, 0.9162, 0.9162, 0.9162, 0.9162, 0.9162, 0.9162]),
+ num_inference_steps=1,
+ processing_resolution=32,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_intrinsics_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
+ self._test_marigold_intrinsics(
+ is_fp16=False,
+ device="cuda",
+ generator_seed=0,
+ expected_slice=np.array([0.62127, 0.61906, 0.61687, 0.61946, 0.61903, 0.61961, 0.61808, 0.62099, 0.62894]),
+ num_inference_steps=1,
+ processing_resolution=768,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
+ self._test_marigold_intrinsics(
+ is_fp16=True,
+ device="cuda",
+ generator_seed=0,
+ expected_slice=np.array([0.62109, 0.61914, 0.61719, 0.61963, 0.61914, 0.61963, 0.61816, 0.62109, 0.62891]),
+ num_inference_steps=1,
+ processing_resolution=768,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_intrinsics_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
+ self._test_marigold_intrinsics(
+ is_fp16=True,
+ device="cuda",
+ generator_seed=2024,
+ expected_slice=np.array([0.64111, 0.63916, 0.63623, 0.63965, 0.63916, 0.63965, 0.6377, 0.64062, 0.64941]),
+ num_inference_steps=1,
+ processing_resolution=768,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_intrinsics_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
+ self._test_marigold_intrinsics(
+ is_fp16=True,
+ device="cuda",
+ generator_seed=0,
+ expected_slice=np.array([0.60254, 0.60059, 0.59961, 0.60156, 0.60107, 0.60205, 0.60254, 0.60449, 0.61133]),
+ num_inference_steps=2,
+ processing_resolution=768,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
+ self._test_marigold_intrinsics(
+ is_fp16=True,
+ device="cuda",
+ generator_seed=0,
+ expected_slice=np.array([0.64551, 0.64453, 0.64404, 0.64502, 0.64844, 0.65039, 0.64502, 0.65039, 0.65332]),
+ num_inference_steps=1,
+ processing_resolution=512,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
+ self._test_marigold_intrinsics(
+ is_fp16=True,
+ device="cuda",
+ generator_seed=0,
+ expected_slice=np.array([0.61572, 0.61377, 0.61182, 0.61426, 0.61377, 0.61426, 0.61279, 0.61572, 0.62354]),
+ num_inference_steps=1,
+ processing_resolution=768,
+ ensemble_size=3,
+ ensembling_kwargs={"reduction": "mean"},
+ batch_size=1,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
+ self._test_marigold_intrinsics(
+ is_fp16=True,
+ device="cuda",
+ generator_seed=0,
+ expected_slice=np.array([0.61914, 0.6167, 0.61475, 0.61719, 0.61719, 0.61768, 0.61572, 0.61914, 0.62695]),
+ num_inference_steps=1,
+ processing_resolution=768,
+ ensemble_size=4,
+ ensembling_kwargs={"reduction": "mean"},
+ batch_size=2,
+ match_input_resolution=True,
+ )
+
+ def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self):
+ self._test_marigold_intrinsics(
+ is_fp16=True,
+ device="cuda",
+ generator_seed=0,
+ expected_slice=np.array([0.65332, 0.64697, 0.64648, 0.64844, 0.64697, 0.64111, 0.64941, 0.64209, 0.65332]),
+ num_inference_steps=1,
+ processing_resolution=512,
+ ensemble_size=1,
+ batch_size=1,
+ match_input_resolution=False,
+ )
diff --git a/tests/pipelines/marigold/test_marigold_normals.py b/tests/pipelines/marigold/test_marigold_normals.py
index c86c600be8e5..1797f99b213b 100644
--- a/tests/pipelines/marigold/test_marigold_normals.py
+++ b/tests/pipelines/marigold/test_marigold_normals.py
@@ -1,5 +1,5 @@
-# Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
+# Copyright 2024-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# limitations under the License.
# --------------------------------------------------------------------------
# More information and citation instructions are available on the
-# Marigold project website: https://marigoldmonodepth.github.io
+# Marigold project website: https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
import gc
import random
@@ -32,11 +32,13 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
@@ -285,17 +287,17 @@ def test_marigold_depth_dummy_no_processing_resolution(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class MarigoldNormalsPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def _test_marigold_normals(
self,
@@ -314,8 +316,7 @@ def _test_marigold_normals(
from_pretrained_kwargs["torch_dtype"] = torch.float16
pipe = MarigoldNormalsPipeline.from_pretrained(model_id, **from_pretrained_kwargs)
- if device == "cuda":
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(generator_seed)
@@ -342,7 +343,7 @@ def _test_marigold_normals(
def test_marigold_normals_einstein_f32_cpu_G0_S1_P32_E1_B1_M1(self):
self._test_marigold_normals(
is_fp16=False,
- device="cpu",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.8971, 0.8971, 0.8971, 0.8971, 0.8971, 0.8971, 0.8971, 0.8971, 0.8971]),
num_inference_steps=1,
@@ -355,7 +356,7 @@ def test_marigold_normals_einstein_f32_cpu_G0_S1_P32_E1_B1_M1(self):
def test_marigold_normals_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_normals(
is_fp16=False,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.7980, 0.7952, 0.7914, 0.7931, 0.7871, 0.7816, 0.7844, 0.7710, 0.7601]),
num_inference_steps=1,
@@ -368,7 +369,7 @@ def test_marigold_normals_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
def test_marigold_normals_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_normals(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.7979, 0.7949, 0.7915, 0.7930, 0.7871, 0.7817, 0.7842, 0.7710, 0.7603]),
num_inference_steps=1,
@@ -381,7 +382,7 @@ def test_marigold_normals_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
def test_marigold_normals_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
self._test_marigold_normals(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=2024,
expected_slice=np.array([0.8428, 0.8428, 0.8433, 0.8369, 0.8325, 0.8315, 0.8271, 0.8135, 0.8057]),
num_inference_steps=1,
@@ -394,7 +395,7 @@ def test_marigold_normals_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
def test_marigold_normals_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
self._test_marigold_normals(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.7095, 0.7095, 0.7104, 0.7070, 0.7051, 0.7061, 0.7017, 0.6938, 0.6914]),
num_inference_steps=2,
@@ -407,7 +408,7 @@ def test_marigold_normals_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
def test_marigold_normals_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
self._test_marigold_normals(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.7168, 0.7163, 0.7163, 0.7080, 0.7061, 0.7046, 0.7031, 0.7007, 0.6987]),
num_inference_steps=1,
@@ -420,7 +421,7 @@ def test_marigold_normals_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
def test_marigold_normals_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
self._test_marigold_normals(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.7114, 0.7124, 0.7144, 0.7085, 0.7070, 0.7080, 0.7051, 0.6958, 0.6924]),
num_inference_steps=1,
@@ -434,7 +435,7 @@ def test_marigold_normals_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
def test_marigold_normals_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
self._test_marigold_normals(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.7412, 0.7441, 0.7490, 0.7383, 0.7388, 0.7437, 0.7329, 0.7271, 0.7300]),
num_inference_steps=1,
@@ -448,7 +449,7 @@ def test_marigold_normals_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
def test_marigold_normals_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self):
self._test_marigold_normals(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.7188, 0.7144, 0.7134, 0.7178, 0.7207, 0.7222, 0.7231, 0.7041, 0.6987]),
num_inference_steps=1,
diff --git a/tests/pipelines/mochi/__init__.py b/tests/pipelines/mochi/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py
new file mode 100644
index 000000000000..ea2d015af52a
--- /dev/null
+++ b/tests/pipelines/mochi/test_mochi.py
@@ -0,0 +1,306 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import inspect
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ nightly,
+ numpy_cosine_similarity_distance,
+ require_big_gpu_with_torch_cuda,
+ require_torch_gpu,
+ torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
+ pipeline_class = MochiPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self, num_layers: int = 2):
+ torch.manual_seed(0)
+ transformer = MochiTransformer3DModel(
+ patch_size=2,
+ num_attention_heads=2,
+ attention_head_dim=8,
+ num_layers=num_layers,
+ pooled_projection_dim=16,
+ in_channels=12,
+ out_channels=None,
+ qk_norm="rms_norm",
+ text_embed_dim=32,
+ time_embed_dim=4,
+ activation_fn="swiglu",
+ max_sequence_length=16,
+ )
+ transformer.pos_frequencies.data = transformer.pos_frequencies.new_full(transformer.pos_frequencies.shape, 0)
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLMochi(
+ latent_channels=12,
+ out_channels=3,
+ encoder_block_out_channels=(32, 32, 32, 32),
+ decoder_block_out_channels=(32, 32, 32, 32),
+ layers_per_block=(1, 1, 1, 1, 1),
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 4.5,
+ "height": 16,
+ "width": 16,
+ # 6 * k + 1 is the recommendation
+ "num_frames": 7,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (7, 3, 16, 16))
+ expected_video = torch.randn(7, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+
+@nightly
+@require_torch_gpu
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
+class MochiPipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_mochi(self):
+ generator = torch.Generator("cpu").manual_seed(0)
+
+ pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float16)
+ pipe.enable_model_cpu_offload(device=torch_device)
+ prompt = self.prompt
+
+ videos = pipe(
+ prompt=prompt,
+ height=480,
+ width=848,
+ num_frames=19,
+ generator=generator,
+ num_inference_steps=2,
+ output_type="pt",
+ ).frames
+
+ video = videos[0]
+ expected_video = torch.randn(1, 19, 480, 848, 3).numpy()
+
+ max_diff = numpy_cosine_similarity_distance(video, expected_video)
+ assert max_diff < 1e-3, f"Max diff is too high. got {video}"
diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py
index e51f5103933a..bdd536b6ff86 100644
--- a/tests/pipelines/musicldm/test_musicldm.py
+++ b/tests/pipelines/musicldm/test_musicldm.py
@@ -65,6 +65,8 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
)
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
diff --git a/tests/pipelines/omnigen/__init__.py b/tests/pipelines/omnigen/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py
new file mode 100644
index 000000000000..2f9c4d4e3f8e
--- /dev/null
+++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py
@@ -0,0 +1,146 @@
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
+from diffusers.utils.testing_utils import (
+ numpy_cosine_similarity_distance,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = OmniGenPipeline
+ params = frozenset(["prompt", "guidance_scale"])
+ batch_params = frozenset(["prompt"])
+
+ test_layerwise_casting = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+
+ transformer = OmniGenTransformer2DModel(
+ hidden_size=16,
+ num_attention_heads=4,
+ num_key_value_heads=4,
+ intermediate_size=32,
+ num_layers=1,
+ in_channels=4,
+ time_step_dim=4,
+ rope_scaling={"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4, 4, 4, 4),
+ layers_per_block=1,
+ latent_channels=4,
+ norm_num_groups=1,
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 1,
+ "guidance_scale": 3.0,
+ "output_type": "np",
+ "height": 16,
+ "width": 16,
+ }
+ return inputs
+
+ def test_inference(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ generated_image = pipe(**inputs).images[0]
+
+ self.assertEqual(generated_image.shape, (16, 16, 3))
+
+
+@slow
+@require_torch_gpu
+class OmniGenPipelineSlowTests(unittest.TestCase):
+ pipeline_class = OmniGenPipeline
+ repo_id = "shitao/OmniGen-v1-diffusers"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ return {
+ "prompt": "A photo of a cat",
+ "num_inference_steps": 2,
+ "guidance_scale": 2.5,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+ def test_omnigen_inference(self):
+ pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
+ pipe.enable_model_cpu_offload()
+
+ inputs = self.get_inputs(torch_device)
+
+ image = pipe(**inputs).images[0]
+ image_slice = image[0, :10, :10]
+
+ expected_slice = np.array(
+ [
+ [0.1783447, 0.16772744, 0.14339337],
+ [0.17066911, 0.15521264, 0.13757327],
+ [0.17072496, 0.15531206, 0.13524258],
+ [0.16746324, 0.1564025, 0.13794944],
+ [0.16490817, 0.15258026, 0.13697758],
+ [0.16971767, 0.15826806, 0.13928896],
+ [0.16782972, 0.15547255, 0.13783783],
+ [0.16464645, 0.15281534, 0.13522372],
+ [0.16535294, 0.15301755, 0.13526791],
+ [0.16365296, 0.15092957, 0.13443318],
+ ],
+ dtype=np.float32,
+ )
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
+
+ assert max_diff < 1e-4
diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py
index 7efe8002d17c..6fa96275406f 100644
--- a/tests/pipelines/pag/test_pag_animatediff.py
+++ b/tests/pipelines/pag/test_pag_animatediff.py
@@ -19,7 +19,7 @@
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import torch_device
+from diffusers.utils.testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
@@ -218,7 +218,7 @@ def test_dict_tuple_outputs_equivalent(self):
expected_slice = np.array([0.5295, 0.3947, 0.5300, 0.4864, 0.4518, 0.5315, 0.5440, 0.4775, 0.5538])
return super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -234,14 +234,14 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
- self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+ output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
+ self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
@@ -553,3 +553,11 @@ def test_pag_applied_layers(self):
pag_layers = ["motion_modules.42"]
with self.assertRaises(ValueError):
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "num_images_per_prompt": 1,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd.py b/tests/pipelines/pag/test_pag_controlnet_sd.py
index 8a7eb6f0c675..ee97b0507a34 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd.py
@@ -28,9 +28,7 @@
StableDiffusionControlNetPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
-)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.torch_utils import randn_tensor
from ..pipeline_params import (
@@ -246,3 +244,10 @@ def test_pag_uncond(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"
+
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
index 0a7413e99926..25ef5d253d68 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
@@ -32,10 +32,7 @@
StableDiffusionControlNetPAGInpaintPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- floats_tensor,
-)
+from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
from diffusers.utils.torch_utils import randn_tensor
from ..pipeline_params import (
@@ -243,3 +240,10 @@ def test_pag_uncond(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"
+
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl.py b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
index 6400cc2b7cab..0588e26286a8 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
@@ -42,7 +42,6 @@
PipelineFromPipeTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -54,7 +53,6 @@ class StableDiffusionXLControlNetPAGPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineFromPipeTesterMixin,
- SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLControlNetPAGPipeline
@@ -214,9 +212,6 @@ def test_pag_disable_enable(self):
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
def test_pag_cfg(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -263,3 +258,7 @@ def test_pag_uncond(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"
+
+ @unittest.skip("We test this functionality elsewhere already.")
+ def test_save_load_optional_components(self):
+ pass
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
index b02f4d8b4561..63c7d9fbee2d 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
@@ -41,7 +41,6 @@
PipelineFromPipeTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -53,7 +52,6 @@ class StableDiffusionXLControlNetPAGImg2ImgPipelineFastTests(
PipelineLatentTesterMixin,
PipelineTesterMixin,
PipelineFromPipeTesterMixin,
- SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLControlNetPAGImg2ImgPipeline
diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py
index db0e257760ed..31cd9aa666de 100644
--- a/tests/pipelines/pag/test_pag_hunyuan_dit.py
+++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py
@@ -28,10 +28,7 @@
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
)
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- torch_device,
-)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -121,10 +118,12 @@ def test_inference(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
+ @unittest.skip("Not supported.")
def test_sequential_cpu_offload_forward_pass(self):
# TODO(YiYi) need to fix later
pass
+ @unittest.skip("Not supported.")
def test_sequential_offload_forward_pass_twice(self):
# TODO(YiYi) need to fix later
pass
@@ -134,99 +133,6 @@ def test_inference_batch_single_identical(self):
expected_max_diff=1e-3,
)
- def test_save_load_optional_components(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
-
- prompt = inputs["prompt"]
- generator = inputs["generator"]
- num_inference_steps = inputs["num_inference_steps"]
- output_type = inputs["output_type"]
-
- (
- prompt_embeds,
- negative_prompt_embeds,
- prompt_attention_mask,
- negative_prompt_attention_mask,
- ) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0)
-
- (
- prompt_embeds_2,
- negative_prompt_embeds_2,
- prompt_attention_mask_2,
- negative_prompt_attention_mask_2,
- ) = pipe.encode_prompt(
- prompt,
- device=torch_device,
- dtype=torch.float32,
- text_encoder_index=1,
- )
-
- # inputs with prompt converted to embeddings
- inputs = {
- "prompt_embeds": prompt_embeds,
- "prompt_attention_mask": prompt_attention_mask,
- "negative_prompt_embeds": negative_prompt_embeds,
- "negative_prompt_attention_mask": negative_prompt_attention_mask,
- "prompt_embeds_2": prompt_embeds_2,
- "prompt_attention_mask_2": prompt_attention_mask_2,
- "negative_prompt_embeds_2": negative_prompt_embeds_2,
- "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
- "generator": generator,
- "num_inference_steps": num_inference_steps,
- "output_type": output_type,
- "use_resolution_binning": False,
- }
-
- # set all optional components to None
- for optional_component in pipe._optional_components:
- setattr(pipe, optional_component, None)
-
- output = pipe(**inputs)[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
-
- for optional_component in pipe._optional_components:
- self.assertTrue(
- getattr(pipe_loaded, optional_component) is None,
- f"`{optional_component}` did not stay set to None after loading.",
- )
-
- inputs = self.get_dummy_inputs(torch_device)
-
- generator = inputs["generator"]
- num_inference_steps = inputs["num_inference_steps"]
- output_type = inputs["output_type"]
-
- # inputs with prompt converted to embeddings
- inputs = {
- "prompt_embeds": prompt_embeds,
- "prompt_attention_mask": prompt_attention_mask,
- "negative_prompt_embeds": negative_prompt_embeds,
- "negative_prompt_attention_mask": negative_prompt_attention_mask,
- "prompt_embeds_2": prompt_embeds_2,
- "prompt_attention_mask_2": prompt_attention_mask_2,
- "negative_prompt_embeds_2": negative_prompt_embeds_2,
- "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
- "generator": generator,
- "num_inference_steps": num_inference_steps,
- "output_type": output_type,
- "use_resolution_binning": False,
- }
-
- output_loaded = pipe_loaded(**inputs)[0]
-
- max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
- self.assertLess(max_diff, 1e-4)
-
def test_feed_forward_chunking(self):
device = "cpu"
@@ -356,3 +262,102 @@ def test_pag_applied_layers(self):
pag_layers = ["blocks.0", r"blocks\.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert len(pipe.pag_attn_processors) == 2
+
+ @unittest.skip(
+ "Test not supported as `encode_prompt` is called two times separately which deivates from about 99% of the pipelines we have."
+ )
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ def test_save_load_optional_components(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ prompt = inputs["prompt"]
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ ) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0)
+
+ (
+ prompt_embeds_2,
+ negative_prompt_embeds_2,
+ prompt_attention_mask_2,
+ negative_prompt_attention_mask_2,
+ ) = pipe.encode_prompt(
+ prompt,
+ device=torch_device,
+ dtype=torch.float32,
+ text_encoder_index=1,
+ )
+
+ # inputs with prompt converted to embeddings
+ inputs = {
+ "prompt_embeds": prompt_embeds,
+ "prompt_attention_mask": prompt_attention_mask,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "negative_prompt_attention_mask": negative_prompt_attention_mask,
+ "prompt_embeds_2": prompt_embeds_2,
+ "prompt_attention_mask_2": prompt_attention_mask_2,
+ "negative_prompt_embeds_2": negative_prompt_embeds_2,
+ "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
+ "generator": generator,
+ "num_inference_steps": num_inference_steps,
+ "output_type": output_type,
+ "use_resolution_binning": False,
+ }
+
+ # set all optional components to None
+ for optional_component in pipe._optional_components:
+ setattr(pipe, optional_component, None)
+
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for optional_component in pipe._optional_components:
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ # inputs with prompt converted to embeddings
+ inputs = {
+ "prompt_embeds": prompt_embeds,
+ "prompt_attention_mask": prompt_attention_mask,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "negative_prompt_attention_mask": negative_prompt_attention_mask,
+ "prompt_embeds_2": prompt_embeds_2,
+ "prompt_attention_mask_2": prompt_attention_mask_2,
+ "negative_prompt_embeds_2": negative_prompt_embeds_2,
+ "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
+ "generator": generator,
+ "num_inference_steps": num_inference_steps,
+ "output_type": output_type,
+ "use_resolution_binning": False,
+ }
+
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, 1e-4)
diff --git a/tests/pipelines/pag/test_pag_kolors.py b/tests/pipelines/pag/test_pag_kolors.py
index 8cfb2c3fd16a..9a4f1daa2c05 100644
--- a/tests/pipelines/pag/test_pag_kolors.py
+++ b/tests/pipelines/pag/test_pag_kolors.py
@@ -56,6 +56,8 @@ class KolorsPAGPipelineFastTests(
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
+ supports_dduf = False
+
# Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
@@ -96,7 +98,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
sample_size=128,
)
torch.manual_seed(0)
- text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+ text_encoder = ChatGLMModel.from_pretrained(
+ "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
+ )
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
components = {
@@ -250,3 +254,6 @@ def test_pag_inference(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=3e-3)
+
+ def test_encode_prompt_works_in_isolation(self):
+ return super().test_encode_prompt_works_in_isolation(atol=1e-3, rtol=1e-3)
diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py
index 7de19e0f00fc..63f42416dbca 100644
--- a/tests/pipelines/pag/test_pag_pixart_sigma.py
+++ b/tests/pipelines/pag/test_pag_pixart_sigma.py
@@ -184,82 +184,6 @@ def test_pag_inference(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
- # Copied from tests.pipelines.pixart_sigma.test_pixart.PixArtSigmaPipelineFastTests.test_save_load_optional_components
- def test_save_load_optional_components(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
-
- prompt = inputs["prompt"]
- generator = inputs["generator"]
- num_inference_steps = inputs["num_inference_steps"]
- output_type = inputs["output_type"]
-
- (
- prompt_embeds,
- prompt_attention_mask,
- negative_prompt_embeds,
- negative_prompt_attention_mask,
- ) = pipe.encode_prompt(prompt)
-
- # inputs with prompt converted to embeddings
- inputs = {
- "prompt_embeds": prompt_embeds,
- "prompt_attention_mask": prompt_attention_mask,
- "negative_prompt": None,
- "negative_prompt_embeds": negative_prompt_embeds,
- "negative_prompt_attention_mask": negative_prompt_attention_mask,
- "generator": generator,
- "num_inference_steps": num_inference_steps,
- "output_type": output_type,
- "use_resolution_binning": False,
- }
-
- # set all optional components to None
- for optional_component in pipe._optional_components:
- setattr(pipe, optional_component, None)
-
- output = pipe(**inputs)[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["blocks.1"])
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
-
- for optional_component in pipe._optional_components:
- self.assertTrue(
- getattr(pipe_loaded, optional_component) is None,
- f"`{optional_component}` did not stay set to None after loading.",
- )
-
- inputs = self.get_dummy_inputs(torch_device)
-
- generator = inputs["generator"]
- num_inference_steps = inputs["num_inference_steps"]
- output_type = inputs["output_type"]
-
- # inputs with prompt converted to embeddings
- inputs = {
- "prompt_embeds": prompt_embeds,
- "prompt_attention_mask": prompt_attention_mask,
- "negative_prompt": None,
- "negative_prompt_embeds": negative_prompt_embeds,
- "negative_prompt_attention_mask": negative_prompt_attention_mask,
- "generator": generator,
- "num_inference_steps": num_inference_steps,
- "output_type": output_type,
- "use_resolution_binning": False,
- }
-
- output_loaded = pipe_loaded(**inputs)[0]
-
- max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
- self.assertLess(max_diff, 1e-4)
-
# Because the PAG PixArt Sigma has `pag_applied_layers`.
# Also, we shouldn't be doing `set_default_attn_processor()` after loading
# the pipeline with `pag_applied_layers`.
@@ -419,3 +343,7 @@ def test_components_function(self):
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
+ @unittest.skip("Test is already covered through encode_prompt isolation.")
+ def test_save_load_optional_components(self):
+ pass
diff --git a/tests/pipelines/pag/test_pag_sana.py b/tests/pipelines/pag/test_pag_sana.py
new file mode 100644
index 000000000000..a2c657297860
--- /dev/null
+++ b/tests/pipelines/pag/test_pag_sana.py
@@ -0,0 +1,341 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import Gemma2Config, Gemma2ForCausalLM, GemmaTokenizer
+
+from diffusers import (
+ AutoencoderDC,
+ FlowMatchEulerDiscreteScheduler,
+ SanaPAGPipeline,
+ SanaPipeline,
+ SanaTransformer2DModel,
+)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class SanaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SanaPAGPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = SanaTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ out_channels=4,
+ num_layers=2,
+ num_attention_heads=2,
+ attention_head_dim=4,
+ num_cross_attention_heads=2,
+ cross_attention_head_dim=4,
+ cross_attention_dim=8,
+ caption_channels=8,
+ sample_size=32,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderDC(
+ in_channels=3,
+ latent_channels=4,
+ attention_head_dim=2,
+ encoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ decoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ encoder_block_out_channels=(8, 8),
+ decoder_block_out_channels=(8, 8),
+ encoder_qkv_multiscales=((), (5,)),
+ decoder_qkv_multiscales=((), (5,)),
+ encoder_layers_per_block=(1, 1),
+ decoder_layers_per_block=[1, 1],
+ downsample_block_type="conv",
+ upsample_block_type="interpolate",
+ decoder_norm_types="rms_norm",
+ decoder_act_fns="silu",
+ scaling_factor=0.41407,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ torch.manual_seed(0)
+ text_encoder_config = Gemma2Config(
+ head_dim=16,
+ hidden_size=32,
+ initializer_range=0.02,
+ intermediate_size=64,
+ max_position_embeddings=8192,
+ model_type="gemma2",
+ num_attention_heads=2,
+ num_hidden_layers=1,
+ num_key_value_heads=2,
+ vocab_size=8,
+ attn_implementation="eager",
+ )
+ text_encoder = Gemma2ForCausalLM(text_encoder_config)
+ tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "pag_scale": 3.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "complex_human_instruction": None,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+ expected_image = torch.randn(3, 32, 32)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_pag_disable_enable(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+
+ # base pipeline (expect same output when pag is disabled)
+ pipe_sd = SanaPipeline(**components)
+ pipe_sd = pipe_sd.to(device)
+ pipe_sd.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ del inputs["pag_scale"]
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
+
+ components = self.get_dummy_components()
+
+ # pag disabled with pag_scale=0.0
+ pipe_pag = self.pipeline_class(**components)
+ pipe_pag = pipe_pag.to(device)
+ pipe_pag.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["pag_scale"] = 0.0
+ out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
+
+ assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
+
+ def test_pag_applied_layers(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+
+ # base pipeline
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k]
+ original_attn_procs = pipe.transformer.attn_processors
+ pag_layers = ["blocks.0", "blocks.1"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
+
+ # blocks.0
+ block_0_self_attn = ["transformer_blocks.0.attn1.processor"]
+ pipe.transformer.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["blocks.0"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(block_0_self_attn)
+
+ pipe.transformer.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["blocks.0.attn1"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(block_0_self_attn)
+
+ pipe.transformer.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["blocks.(0|1)"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert (len(pipe.pag_attn_processors)) == 2
+
+ pipe.transformer.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["blocks.0", r"blocks\.1"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert len(pipe.pag_attn_processors) == 2
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ def test_float16_inference(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_float16_inference(expected_max_diff=0.08)
diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py
index 3979bb170e0b..d4cf00b034ff 100644
--- a/tests/pipelines/pag/test_pag_sd.py
+++ b/tests/pipelines/pag/test_pag_sd.py
@@ -30,8 +30,9 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -47,7 +48,6 @@
PipelineFromPipeTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -59,7 +59,6 @@ class StableDiffusionPAGPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineFromPipeTesterMixin,
- SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionPAGPipeline
@@ -278,9 +277,16 @@ def test_pag_inference(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusionPAGPipeline
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
@@ -288,12 +294,12 @@ class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", seed=1, guidance_scale=7.0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -310,7 +316,7 @@ def get_inputs(self, device, generator_device="cpu", seed=1, guidance_scale=7.0)
def test_pag_cfg(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
@@ -318,7 +324,7 @@ def test_pag_cfg(self):
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
- print(image_slice.flatten())
+
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
@@ -328,7 +334,7 @@ def test_pag_cfg(self):
def test_pag_uncond(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device, guidance_scale=0.0)
@@ -339,7 +345,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- print(image_slice.flatten())
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sd3.py b/tests/pipelines/pag/test_pag_sd3.py
index 627d613ee20d..41ff0c3c09f4 100644
--- a/tests/pipelines/pag/test_pag_sd3.py
+++ b/tests/pipelines/pag/test_pag_sd3.py
@@ -156,39 +156,6 @@ def test_stable_diffusion_3_different_negative_prompts(self):
# Outputs should be different here
assert max_diff > 1e-2
- def test_stable_diffusion_3_prompt_embeds(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
- inputs = self.get_dummy_inputs(torch_device)
-
- output_with_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- do_classifier_free_guidance = inputs["guidance_scale"] > 1
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = pipe.encode_prompt(
- prompt,
- prompt_2=None,
- prompt_3=None,
- do_classifier_free_guidance=do_classifier_free_guidance,
- device=torch_device,
- )
- output_with_embeds = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- **inputs,
- ).images[0]
-
- max_diff = np.abs(output_with_prompt - output_with_embeds).max()
- assert max_diff < 1e-4
-
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_sd3_img2img.py b/tests/pipelines/pag/test_pag_sd3_img2img.py
new file mode 100644
index 000000000000..2fe988929185
--- /dev/null
+++ b/tests/pipelines/pag/test_pag_sd3_img2img.py
@@ -0,0 +1,277 @@
+import gc
+import inspect
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ AutoPipelineForImage2Image,
+ FlowMatchEulerDiscreteScheduler,
+ SD3Transformer2DModel,
+ StableDiffusion3Img2ImgPipeline,
+ StableDiffusion3PAGImg2ImgPipeline,
+)
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ floats_tensor,
+ load_image,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+
+from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+ TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
+)
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class StableDiffusion3PAGImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = StableDiffusion3PAGImg2ImgPipeline
+ params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"pag_scale", "pag_adaptive_scale"}) - {"height", "width"}
+ required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
+ batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latens_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
+
+ test_xformers_attention = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = SD3Transformer2DModel(
+ sample_size=32,
+ patch_size=1,
+ in_channels=4,
+ num_layers=2,
+ attention_head_dim=8,
+ num_attention_heads=4,
+ caption_projection_dim=32,
+ joint_attention_dim=32,
+ pooled_projection_dim=64,
+ out_channels=4,
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
+
+ text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=4,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "text_encoder_3": text_encoder_3,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "tokenizer_3": tokenizer_3,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ image = image / 2 + 0.5
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "output_type": "np",
+ "pag_scale": 0.7,
+ }
+ return inputs
+
+ def test_pag_disable_enable(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+
+ # base pipeline (expect same output when pag is disabled)
+ pipe_sd = StableDiffusion3Img2ImgPipeline(**components)
+ pipe_sd = pipe_sd.to(device)
+ pipe_sd.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ del inputs["pag_scale"]
+ assert (
+ "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
+ ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
+
+ components = self.get_dummy_components()
+
+ # pag disabled with pag_scale=0.0
+ pipe_pag = self.pipeline_class(**components)
+ pipe_pag = pipe_pag.to(device)
+ pipe_pag.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["pag_scale"] = 0.0
+ out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
+
+ assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
+
+ def test_pag_inference(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+
+ pipe_pag = self.pipeline_class(**components, pag_applied_layers=["blocks.0"])
+ pipe_pag = pipe_pag.to(device)
+ pipe_pag.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe_pag(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (
+ 1,
+ 32,
+ 32,
+ 3,
+ ), f"the shape of the output image should be (1, 32, 32, 3) but got {image.shape}"
+
+ expected_slice = np.array(
+ [0.66063476, 0.44838923, 0.5484299, 0.7242875, 0.5970012, 0.6015729, 0.53080845, 0.52220416, 0.56397927]
+ )
+ max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+ self.assertLessEqual(max_diff, 1e-3)
+
+
+@slow
+@require_torch_accelerator
+class StableDiffusion3PAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
+ pipeline_class = StableDiffusion3PAGImg2ImgPipeline
+ repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_inputs(
+ self, device, generator_device="cpu", dtype=torch.float32, seed=0, guidance_scale=7.0, pag_scale=0.7
+ ):
+ img_url = (
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
+ )
+ init_image = load_image(img_url)
+
+ generator = torch.Generator(device=generator_device).manual_seed(seed)
+ inputs = {
+ "prompt": "an astronaut in a space suit walking through a jungle",
+ "generator": generator,
+ "image": init_image,
+ "num_inference_steps": 12,
+ "strength": 0.6,
+ "guidance_scale": guidance_scale,
+ "pag_scale": pag_scale,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_pag_cfg(self):
+ pipeline = AutoPipelineForImage2Image.from_pretrained(
+ self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.17"]
+ )
+ pipeline.enable_model_cpu_offload(device=torch_device)
+ pipeline.set_progress_bar_config(disable=None)
+
+ inputs = self.get_inputs(torch_device)
+ image = pipeline(**inputs).images
+ image_slice = image[0, -3:, -3:, -1].flatten()
+ assert image.shape == (1, 1024, 1024, 3)
+ expected_slice = np.array(
+ [
+ 0.16772461,
+ 0.17626953,
+ 0.18432617,
+ 0.17822266,
+ 0.18359375,
+ 0.17626953,
+ 0.17407227,
+ 0.17700195,
+ 0.17822266,
+ ]
+ )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
+
+ def test_pag_uncond(self):
+ pipeline = AutoPipelineForImage2Image.from_pretrained(
+ self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.(4|17)"]
+ )
+ pipeline.enable_model_cpu_offload(device=torch_device)
+ pipeline.set_progress_bar_config(disable=None)
+
+ inputs = self.get_inputs(torch_device, guidance_scale=0.0, pag_scale=1.8)
+ image = pipeline(**inputs).images
+ image_slice = image[0, -3:, -3:, -1].flatten()
+ assert image.shape == (1, 1024, 1024, 3)
+ expected_slice = np.array(
+ [0.1508789, 0.16210938, 0.17138672, 0.16210938, 0.17089844, 0.16137695, 0.16235352, 0.16430664, 0.16455078]
+ )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sd_img2img.py b/tests/pipelines/pag/test_pag_sd_img2img.py
index ec8cde23c31d..d000493d6bd1 100644
--- a/tests/pipelines/pag/test_pag_sd_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd_img2img.py
@@ -32,10 +32,11 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -210,9 +211,16 @@ def test_pag_inference(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionPAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusionPAGImg2ImgPipeline
repo_id = "Jiali/stable-diffusion-1.5"
@@ -220,12 +228,12 @@ class StableDiffusionPAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -247,7 +255,7 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0
def test_pag_cfg(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
@@ -255,7 +263,7 @@ def test_pag_cfg(self):
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
- print(image_slice.flatten())
+
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
@@ -265,7 +273,7 @@ def test_pag_cfg(self):
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device, guidance_scale=0.0)
@@ -276,7 +284,7 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- print(image_slice.flatten())
+
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sd_inpaint.py b/tests/pipelines/pag/test_pag_sd_inpaint.py
new file mode 100644
index 000000000000..06682c111d37
--- /dev/null
+++ b/tests/pipelines/pag/test_pag_sd_inpaint.py
@@ -0,0 +1,324 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import random
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ AutoPipelineForInpainting,
+ PNDMScheduler,
+ StableDiffusionPAGInpaintPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ floats_tensor,
+ load_image,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+
+from ..pipeline_params import (
+ TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
+ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
+ TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
+)
+from ..test_pipelines_common import (
+ IPAdapterTesterMixin,
+ PipelineFromPipeTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class StableDiffusionPAGInpaintPipelineFastTests(
+ PipelineTesterMixin,
+ IPAdapterTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineFromPipeTesterMixin,
+ unittest.TestCase,
+):
+ pipeline_class = StableDiffusionPAGInpaintPipeline
+ params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
+ batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
+ image_params = frozenset([])
+ image_latents_params = frozenset([])
+ callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
+ {"add_text_embeds", "add_time_ids", "mask", "masked_image_latents"}
+ )
+
+ def get_dummy_components(self, time_cond_proj_dim=None):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ time_cond_proj_dim=time_cond_proj_dim,
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ )
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "safety_checker": None,
+ "feature_extractor": None,
+ "image_encoder": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
+ # create mask
+ image[8:, 8:, :] = 255
+ mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
+
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": init_image,
+ "mask_image": mask_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "strength": 1.0,
+ "pag_scale": 0.9,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_pag_applied_layers(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+
+ # base pipeline
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ # pag_applied_layers = ["mid","up","down"] should apply to all self-attention layers
+ all_self_attn_layers = [k for k in pipe.unet.attn_processors.keys() if "attn1" in k]
+ original_attn_procs = pipe.unet.attn_processors
+ pag_layers = [
+ "down",
+ "mid",
+ "up",
+ ]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
+
+ # pag_applied_layers = ["mid"], or ["mid.block_0"] or ["mid.block_0.attentions_0"] should apply to all self-attention layers in mid_block, i.e.
+ # mid_block.attentions.0.transformer_blocks.0.attn1.processor
+ # mid_block.attentions.0.transformer_blocks.1.attn1.processor
+ all_self_attn_mid_layers = [
+ "mid_block.attentions.0.transformer_blocks.0.attn1.processor",
+ # "mid_block.attentions.0.transformer_blocks.1.attn1.processor",
+ ]
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["mid"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["mid_block"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["mid_block.attentions.0"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
+
+ # pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["mid_block.attentions.1"]
+ with self.assertRaises(ValueError):
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+
+ # pag_applied_layers = "down" should apply to all self-attention layers in down_blocks
+ # down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
+ # down_blocks.1.attentions.0.transformer_blocks.1.attn1.processor
+ # down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["down"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert len(pipe.pag_attn_processors) == 2
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["down_blocks.0"]
+ with self.assertRaises(ValueError):
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["down_blocks.1"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert len(pipe.pag_attn_processors) == 2
+
+ pipe.unet.set_attn_processor(original_attn_procs.copy())
+ pag_layers = ["down_blocks.1.attentions.1"]
+ pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
+ assert len(pipe.pag_attn_processors) == 1
+
+ def test_pag_inference(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+
+ pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
+ pipe_pag = pipe_pag.to(device)
+ pipe_pag.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe_pag(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (
+ 1,
+ 64,
+ 64,
+ 3,
+ ), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}"
+
+ expected_slice = np.array([0.7190, 0.5807, 0.6007, 0.5600, 0.6350, 0.6639, 0.5680, 0.5664, 0.5230])
+ max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+ assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"
+
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol=1e-3, rtol=1e-3)
+
+
+@slow
+@require_torch_accelerator
+class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase):
+ pipeline_class = StableDiffusionPAGInpaintPipeline
+ repo_id = "runwayml/stable-diffusion-v1-5"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0):
+ img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ init_image = load_image(img_url).convert("RGB")
+ mask_image = load_image(mask_url).convert("RGB")
+
+ generator = torch.Generator(device=generator_device).manual_seed(seed)
+ inputs = {
+ "prompt": "A majestic tiger sitting on a bench",
+ "generator": generator,
+ "image": init_image,
+ "mask_image": mask_image,
+ "strength": 0.8,
+ "num_inference_steps": 3,
+ "guidance_scale": guidance_scale,
+ "pag_scale": 3.0,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_pag_cfg(self):
+ pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
+ pipeline.enable_model_cpu_offload(device=torch_device)
+ pipeline.set_progress_bar_config(disable=None)
+
+ inputs = self.get_inputs(torch_device)
+ image = pipeline(**inputs).images
+
+ image_slice = image[0, -3:, -3:, -1].flatten()
+ assert image.shape == (1, 512, 512, 3)
+
+ expected_slice = np.array(
+ [0.38793945, 0.4111328, 0.47924805, 0.39208984, 0.4165039, 0.41674805, 0.37060547, 0.36791992, 0.40625]
+ )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
+
+ def test_pag_uncond(self):
+ pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
+ pipeline.enable_model_cpu_offload(device=torch_device)
+ pipeline.set_progress_bar_config(disable=None)
+
+ inputs = self.get_inputs(torch_device, guidance_scale=0.0)
+ image = pipeline(**inputs).images
+
+ image_slice = image[0, -3:, -3:, -1].flatten()
+ assert image.shape == (1, 512, 512, 3)
+ expected_slice = np.array(
+ [0.3876953, 0.40356445, 0.4934082, 0.39697266, 0.41674805, 0.41015625, 0.375, 0.36914062, 0.40649414]
+ )
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ ), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py
index 589573385677..b35b2b1d2f7e 100644
--- a/tests/pipelines/pag/test_pag_sdxl.py
+++ b/tests/pipelines/pag/test_pag_sdxl.py
@@ -30,8 +30,9 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -47,7 +48,6 @@
PipelineFromPipeTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -59,7 +59,6 @@ class StableDiffusionXLPAGPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineFromPipeTesterMixin,
- SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLPAGPipeline
@@ -193,9 +192,6 @@ def test_pag_disable_enable(self):
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
def test_pag_applied_layers(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -288,9 +284,13 @@ def test_pag_inference(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
+ @unittest.skip("We test this functionality elsewhere already.")
+ def test_save_load_optional_components(self):
+ pass
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionXLPAGPipelineIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusionXLPAGPipeline
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -298,12 +298,12 @@ class StableDiffusionXLPAGPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -320,7 +320,7 @@ def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0)
def test_pag_cfg(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
@@ -337,7 +337,7 @@ def test_pag_cfg(self):
def test_pag_uncond(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device, guidance_scale=0.0)
diff --git a/tests/pipelines/pag/test_pag_sdxl_img2img.py b/tests/pipelines/pag/test_pag_sdxl_img2img.py
index 7e5fc5fa28b9..c94a6836de7f 100644
--- a/tests/pipelines/pag/test_pag_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_sdxl_img2img.py
@@ -39,10 +39,11 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -58,7 +59,6 @@
PipelineFromPipeTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -70,7 +70,6 @@ class StableDiffusionXLPAGImg2ImgPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineFromPipeTesterMixin,
- SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLPAGImg2ImgPipeline
@@ -82,6 +81,8 @@ class StableDiffusionXLPAGImg2ImgPipelineFastTests(
{"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
)
+ supports_dduf = False
+
# based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_img2img_pipeline.get_dummy_components
def get_dummy_components(
self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False
@@ -239,9 +240,6 @@ def test_pag_disable_enable(self):
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
def test_pag_inference(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(requires_aesthetics_score=True)
@@ -265,21 +263,25 @@ def test_pag_inference(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"
+ @unittest.skip("We test this functionality elsewhere already.")
+ def test_save_load_optional_components(self):
+ pass
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionXLPAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0):
img_url = (
@@ -303,7 +305,7 @@ def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0)
def test_pag_cfg(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
@@ -320,7 +322,7 @@ def test_pag_cfg(self):
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device, guidance_scale=0.0)
diff --git a/tests/pipelines/pag/test_pag_sdxl_inpaint.py b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
index efc37abd0682..cca5292288b0 100644
--- a/tests/pipelines/pag/test_pag_sdxl_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
@@ -40,10 +40,11 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -58,7 +59,6 @@
PipelineFromPipeTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -70,7 +70,6 @@ class StableDiffusionXLPAGInpaintPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineFromPipeTesterMixin,
- SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLPAGInpaintPipeline
@@ -82,6 +81,8 @@ class StableDiffusionXLPAGInpaintPipelineFastTests(
{"add_text_embeds", "add_time_ids", "mask", "masked_image_latents"}
)
+ supports_dduf = False
+
# based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipelineFastTests.get_dummy_components
def get_dummy_components(
self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False
@@ -244,9 +245,6 @@ def test_pag_disable_enable(self):
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
def test_pag_inference(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(requires_aesthetics_score=True)
@@ -270,21 +268,25 @@ def test_pag_inference(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"
+ @unittest.skip("We test this functionality elsewhere already.")
+ def test_save_load_optional_components(self):
+ pass
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionXLPAGInpaintPipelineIntegrationTests(unittest.TestCase):
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0):
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
@@ -309,7 +311,7 @@ def get_inputs(self, device, generator_device="cpu", seed=0, guidance_scale=7.0)
def test_pag_cfg(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
@@ -326,7 +328,7 @@ def test_pag_cfg(self):
def test_pag_uncond(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
- pipeline.enable_model_cpu_offload()
+ pipeline.enable_model_cpu_offload(device=torch_device)
pipeline.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device, guidance_scale=0.0)
diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py
index c71e2d4761c2..6b668de2762a 100644
--- a/tests/pipelines/paint_by_example/test_paint_by_example.py
+++ b/tests/pipelines/paint_by_example/test_paint_by_example.py
@@ -46,6 +46,8 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
image_params = frozenset([]) # TO_DO: update the image_prams once refactored VaeImageProcessor.preprocess
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py
index ca558fbb83e5..1156bf32dafa 100644
--- a/tests/pipelines/pia/test_pia.py
+++ b/tests/pipelines/pia/test_pia.py
@@ -18,7 +18,7 @@
UNetMotionModel,
)
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import floats_tensor, torch_device
+from diffusers.utils.testing_utils import floats_tensor, require_accelerator, torch_device
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
@@ -55,6 +55,8 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr
"callback_on_step_end_tensor_inputs",
]
)
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
cross_attention_dim = 8
@@ -278,7 +280,7 @@ def test_inference_batch_single_identical(
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -294,14 +296,14 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
- self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+ output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
+ self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
@@ -436,3 +438,11 @@ def test_xformers_attention_forwardGenerator_pass(self):
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
+
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "num_images_per_prompt": 1,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py
index e7039c61a448..ea5cfcef86fd 100644
--- a/tests/pipelines/pixart_alpha/test_pixart.py
+++ b/tests/pipelines/pixart_alpha/test_pixart.py
@@ -28,9 +28,10 @@
PixArtTransformer2DModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -50,6 +51,8 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
@@ -102,85 +105,11 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs
+ @unittest.skip("Not supported.")
def test_sequential_cpu_offload_forward_pass(self):
# TODO(PVP, Sayak) need to fix later
return
- def test_save_load_optional_components(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
-
- prompt = inputs["prompt"]
- generator = inputs["generator"]
- num_inference_steps = inputs["num_inference_steps"]
- output_type = inputs["output_type"]
-
- (
- prompt_embeds,
- prompt_attention_mask,
- negative_prompt_embeds,
- negative_prompt_attention_mask,
- ) = pipe.encode_prompt(prompt)
-
- # inputs with prompt converted to embeddings
- inputs = {
- "prompt_embeds": prompt_embeds,
- "prompt_attention_mask": prompt_attention_mask,
- "negative_prompt": None,
- "negative_prompt_embeds": negative_prompt_embeds,
- "negative_prompt_attention_mask": negative_prompt_attention_mask,
- "generator": generator,
- "num_inference_steps": num_inference_steps,
- "output_type": output_type,
- "use_resolution_binning": False,
- }
-
- # set all optional components to None
- for optional_component in pipe._optional_components:
- setattr(pipe, optional_component, None)
-
- output = pipe(**inputs)[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
-
- for optional_component in pipe._optional_components:
- self.assertTrue(
- getattr(pipe_loaded, optional_component) is None,
- f"`{optional_component}` did not stay set to None after loading.",
- )
-
- inputs = self.get_dummy_inputs(torch_device)
-
- generator = inputs["generator"]
- num_inference_steps = inputs["num_inference_steps"]
- output_type = inputs["output_type"]
-
- # inputs with prompt converted to embeddings
- inputs = {
- "prompt_embeds": prompt_embeds,
- "prompt_attention_mask": prompt_attention_mask,
- "negative_prompt": None,
- "negative_prompt_embeds": negative_prompt_embeds,
- "negative_prompt_attention_mask": negative_prompt_attention_mask,
- "generator": generator,
- "num_inference_steps": num_inference_steps,
- "output_type": output_type,
- "use_resolution_binning": False,
- }
-
- output_loaded = pipe_loaded(**inputs)[0]
-
- max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
- self.assertLess(max_diff, 1e-4)
-
def test_inference(self):
device = "cpu"
@@ -215,6 +144,10 @@ def test_inference_non_square_images(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
+ @unittest.skip("Test is already covered through encode_prompt isolation.")
+ def test_save_load_optional_components(self):
+ pass
+
def test_inference_with_embeddings_and_multiple_images(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -326,7 +259,7 @@ def test_inference_batch_single_identical(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
ckpt_id_1024 = "PixArt-alpha/PixArt-XL-2-1024-MS"
ckpt_id_512 = "PixArt-alpha/PixArt-XL-2-512x512"
@@ -335,18 +268,18 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_pixart_1024(self):
generator = torch.Generator("cpu").manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
@@ -361,7 +294,7 @@ def test_pixart_512(self):
generator = torch.Generator("cpu").manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
@@ -377,7 +310,7 @@ def test_pixart_1024_without_resolution_binning(self):
generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
height, width = 1024, 768
@@ -411,7 +344,7 @@ def test_pixart_512_without_resolution_binning(self):
generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
height, width = 512, 768
diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py
index a92e99366ee3..b220afcfc25a 100644
--- a/tests/pipelines/pixart_sigma/test_pixart.py
+++ b/tests/pipelines/pixart_sigma/test_pixart.py
@@ -28,9 +28,10 @@
PixArtTransformer2DModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -55,6 +56,8 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
@@ -107,85 +110,11 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs
+ @unittest.skip("Not supported.")
def test_sequential_cpu_offload_forward_pass(self):
# TODO(PVP, Sayak) need to fix later
return
- def test_save_load_optional_components(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
-
- prompt = inputs["prompt"]
- generator = inputs["generator"]
- num_inference_steps = inputs["num_inference_steps"]
- output_type = inputs["output_type"]
-
- (
- prompt_embeds,
- prompt_attention_mask,
- negative_prompt_embeds,
- negative_prompt_attention_mask,
- ) = pipe.encode_prompt(prompt)
-
- # inputs with prompt converted to embeddings
- inputs = {
- "prompt_embeds": prompt_embeds,
- "prompt_attention_mask": prompt_attention_mask,
- "negative_prompt": None,
- "negative_prompt_embeds": negative_prompt_embeds,
- "negative_prompt_attention_mask": negative_prompt_attention_mask,
- "generator": generator,
- "num_inference_steps": num_inference_steps,
- "output_type": output_type,
- "use_resolution_binning": False,
- }
-
- # set all optional components to None
- for optional_component in pipe._optional_components:
- setattr(pipe, optional_component, None)
-
- output = pipe(**inputs)[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
-
- for optional_component in pipe._optional_components:
- self.assertTrue(
- getattr(pipe_loaded, optional_component) is None,
- f"`{optional_component}` did not stay set to None after loading.",
- )
-
- inputs = self.get_dummy_inputs(torch_device)
-
- generator = inputs["generator"]
- num_inference_steps = inputs["num_inference_steps"]
- output_type = inputs["output_type"]
-
- # inputs with prompt converted to embeddings
- inputs = {
- "prompt_embeds": prompt_embeds,
- "prompt_attention_mask": prompt_attention_mask,
- "negative_prompt": None,
- "negative_prompt_embeds": negative_prompt_embeds,
- "negative_prompt_attention_mask": negative_prompt_attention_mask,
- "generator": generator,
- "num_inference_steps": num_inference_steps,
- "output_type": output_type,
- "use_resolution_binning": False,
- }
-
- output_loaded = pipe_loaded(**inputs)[0]
-
- max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
- self.assertLess(max_diff, 1e-4)
-
def test_inference(self):
device = "cpu"
@@ -310,6 +239,10 @@ def test_inference_with_multiple_images_per_prompt(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
+ @unittest.skip("Test is already covered through encode_prompt isolation.")
+ def test_save_load_optional_components(self):
+ pass
+
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
@@ -355,7 +288,7 @@ def test_fused_qkv_projections(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class PixArtSigmaPipelineIntegrationTests(unittest.TestCase):
ckpt_id_1024 = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
ckpt_id_512 = "PixArt-alpha/PixArt-Sigma-XL-2-512-MS"
@@ -364,18 +297,18 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_pixart_1024(self):
generator = torch.Generator("cpu").manual_seed(0)
pipe = PixArtSigmaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
@@ -395,7 +328,7 @@ def test_pixart_512(self):
pipe = PixArtSigmaPipeline.from_pretrained(
self.ckpt_id_1024, transformer=transformer, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
@@ -411,7 +344,7 @@ def test_pixart_1024_without_resolution_binning(self):
generator = torch.manual_seed(0)
pipe = PixArtSigmaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
height, width = 1024, 768
@@ -450,7 +383,7 @@ def test_pixart_512_without_resolution_binning(self):
pipe = PixArtSigmaPipeline.from_pretrained(
self.ckpt_id_1024, transformer=transformer, torch_dtype=torch.float16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
prompt = self.prompt
height, width = 512, 768
diff --git a/tests/pipelines/sana/__init__.py b/tests/pipelines/sana/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py
new file mode 100644
index 000000000000..aa5d5c7ce463
--- /dev/null
+++ b/tests/pipelines/sana/test_sana.py
@@ -0,0 +1,373 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
+
+from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SanaPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = SanaTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ out_channels=4,
+ num_layers=1,
+ num_attention_heads=2,
+ attention_head_dim=4,
+ num_cross_attention_heads=2,
+ cross_attention_head_dim=4,
+ cross_attention_dim=8,
+ caption_channels=8,
+ sample_size=32,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderDC(
+ in_channels=3,
+ latent_channels=4,
+ attention_head_dim=2,
+ encoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ decoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ encoder_block_out_channels=(8, 8),
+ decoder_block_out_channels=(8, 8),
+ encoder_qkv_multiscales=((), (5,)),
+ decoder_qkv_multiscales=((), (5,)),
+ encoder_layers_per_block=(1, 1),
+ decoder_layers_per_block=[1, 1],
+ downsample_block_type="conv",
+ upsample_block_type="interpolate",
+ decoder_norm_types="rms_norm",
+ decoder_act_fns="silu",
+ scaling_factor=0.41407,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ torch.manual_seed(0)
+ text_encoder_config = Gemma2Config(
+ head_dim=16,
+ hidden_size=8,
+ initializer_range=0.02,
+ intermediate_size=64,
+ max_position_embeddings=8192,
+ model_type="gemma2",
+ num_attention_heads=2,
+ num_hidden_layers=1,
+ num_key_value_heads=2,
+ vocab_size=8,
+ attn_implementation="eager",
+ )
+ text_encoder = Gemma2Model(text_encoder_config)
+ tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "complex_human_instruction": None,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+ expected_image = torch.randn(3, 32, 32)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ def test_float16_inference(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_float16_inference(expected_max_diff=0.08)
+
+
+@slow
+@require_torch_accelerator
+class SanaPipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_sana_1024(self):
+ generator = torch.Generator("cpu").manual_seed(0)
+
+ pipe = SanaPipeline.from_pretrained(
+ "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
+ )
+ pipe.enable_model_cpu_offload(device=torch_device)
+
+ image = pipe(
+ prompt=self.prompt,
+ height=1024,
+ width=1024,
+ generator=generator,
+ num_inference_steps=20,
+ output_type="np",
+ ).images[0]
+
+ image = image.flatten()
+ output_slice = np.concatenate((image[:16], image[-16:]))
+
+ # fmt: off
+ expected_slice = np.array([0.0427, 0.0789, 0.0662, 0.0464, 0.082, 0.0574, 0.0535, 0.0886, 0.0647, 0.0549, 0.0872, 0.0605, 0.0593, 0.0942, 0.0674, 0.0581, 0.0076, 0.0168, 0.0027, 0.0063, 0.0159, 0.0, 0.0071, 0.0198, 0.0034, 0.0105, 0.0212, 0.0, 0.0, 0.0166, 0.0042, 0.0125])
+ # fmt: on
+
+ self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-4))
+
+ def test_sana_512(self):
+ generator = torch.Generator("cpu").manual_seed(0)
+
+ pipe = SanaPipeline.from_pretrained(
+ "Efficient-Large-Model/Sana_1600M_512px_diffusers", torch_dtype=torch.float16
+ )
+ pipe.enable_model_cpu_offload(device=torch_device)
+
+ image = pipe(
+ prompt=self.prompt,
+ height=512,
+ width=512,
+ generator=generator,
+ num_inference_steps=20,
+ output_type="np",
+ ).images[0]
+
+ image = image.flatten()
+ output_slice = np.concatenate((image[:16], image[-16:]))
+
+ # fmt: off
+ expected_slice = np.array([0.0803, 0.0774, 0.1108, 0.0872, 0.093, 0.1118, 0.0952, 0.0898, 0.1038, 0.0818, 0.0754, 0.0894, 0.074, 0.0691, 0.0906, 0.0671, 0.0154, 0.0254, 0.0203, 0.0178, 0.0283, 0.0193, 0.0215, 0.0273, 0.0188, 0.0212, 0.0273, 0.0151, 0.0061, 0.0244, 0.0212, 0.0259])
+ # fmt: on
+
+ self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-4))
diff --git a/tests/pipelines/sana/test_sana_sprint.py b/tests/pipelines/sana/test_sana_sprint.py
new file mode 100644
index 000000000000..d006c2b986ca
--- /dev/null
+++ b/tests/pipelines/sana/test_sana_sprint.py
@@ -0,0 +1,302 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
+
+from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SanaSprintPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt", "negative_prompt_embeds"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {"negative_prompt"}
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - {"negative_prompt"}
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = SanaTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ out_channels=4,
+ num_layers=1,
+ num_attention_heads=2,
+ attention_head_dim=4,
+ num_cross_attention_heads=2,
+ cross_attention_head_dim=4,
+ cross_attention_dim=8,
+ caption_channels=8,
+ sample_size=32,
+ qk_norm="rms_norm_across_heads",
+ guidance_embeds=True,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderDC(
+ in_channels=3,
+ latent_channels=4,
+ attention_head_dim=2,
+ encoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ decoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ encoder_block_out_channels=(8, 8),
+ decoder_block_out_channels=(8, 8),
+ encoder_qkv_multiscales=((), (5,)),
+ decoder_qkv_multiscales=((), (5,)),
+ encoder_layers_per_block=(1, 1),
+ decoder_layers_per_block=[1, 1],
+ downsample_block_type="conv",
+ upsample_block_type="interpolate",
+ decoder_norm_types="rms_norm",
+ decoder_act_fns="silu",
+ scaling_factor=0.41407,
+ )
+
+ torch.manual_seed(0)
+ scheduler = SCMScheduler()
+
+ torch.manual_seed(0)
+ text_encoder_config = Gemma2Config(
+ head_dim=16,
+ hidden_size=8,
+ initializer_range=0.02,
+ intermediate_size=64,
+ max_position_embeddings=8192,
+ model_type="gemma2",
+ num_attention_heads=2,
+ num_hidden_layers=1,
+ num_key_value_heads=2,
+ vocab_size=8,
+ attn_implementation="eager",
+ )
+ text_encoder = Gemma2Model(text_encoder_config)
+ tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "complex_human_instruction": None,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+ expected_image = torch.randn(3, 32, 32)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ def test_float16_inference(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_float16_inference(expected_max_diff=0.08)
diff --git a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py
index 990c389a9c5f..6cd431f02d58 100644
--- a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py
+++ b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py
@@ -28,6 +28,7 @@
enable_full_determinism,
floats_tensor,
nightly,
+ require_accelerator,
require_torch_gpu,
torch_device,
)
@@ -237,7 +238,7 @@ def test_semantic_diffusion_no_safety_checker(self):
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None
- @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
+ @require_accelerator
def test_semantic_diffusion_fp16(self):
"""Test that stable diffusion works with fp16"""
unet = self.dummy_cond_unet
diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py
index f3661355e9dd..ac7096874b31 100644
--- a/tests/pipelines/shap_e/test_shap_e_img2img.py
+++ b/tests/pipelines/shap_e/test_shap_e_img2img.py
@@ -50,6 +50,8 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
]
test_xformers_attention = False
+ supports_dduf = False
+
@property
def text_embedder_hidden_size(self):
return 16
diff --git a/tests/pipelines/stable_audio/test_stable_audio.py b/tests/pipelines/stable_audio/test_stable_audio.py
index 41ac94891c6f..01df82056ce2 100644
--- a/tests/pipelines/stable_audio/test_stable_audio.py
+++ b/tests/pipelines/stable_audio/test_stable_audio.py
@@ -70,6 +70,7 @@ class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
# There is not xformers version of the StableAudioPipeline custom attention processor
test_xformers_attention = False
+ supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
@@ -412,6 +413,10 @@ def test_sequential_cpu_offload_forward_pass(self):
def test_sequential_offload_forward_pass_twice(self):
pass
+ @unittest.skip("Test not supported because `rotary_embed_dim` doesn't have any sensible default.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
@nightly
@require_torch_gpu
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
index d256deed376c..1765f3a02242 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
@@ -22,7 +22,7 @@
from diffusers import DDPMWuerstchenScheduler, StableCascadeCombinedPipeline
from diffusers.models import StableCascadeUNet
from diffusers.pipelines.wuerstchen import PaellaVQModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
@@ -205,7 +205,7 @@ def test_stable_cascade(self):
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- @require_torch_gpu
+ @require_torch_accelerator
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -214,12 +214,12 @@ def test_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -242,40 +242,3 @@ def test_float16_inference(self):
@unittest.skip(reason="no callback test for combined pipeline")
def test_callback_inputs(self):
super().test_callback_inputs()
-
- def test_stable_cascade_combined_prompt_embeds(self):
- device = "cpu"
- components = self.get_dummy_components()
-
- pipe = StableCascadeCombinedPipeline(**components)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "A photograph of a shiba inu, wearing a hat"
- (
- prompt_embeds,
- prompt_embeds_pooled,
- negative_prompt_embeds,
- negative_prompt_embeds_pooled,
- ) = pipe.prior_pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
- generator = torch.Generator(device=device)
-
- output_prompt = pipe(
- prompt=prompt,
- num_inference_steps=1,
- prior_num_inference_steps=1,
- output_type="np",
- generator=generator.manual_seed(0),
- )
- output_prompt_embeds = pipe(
- prompt=None,
- prompt_embeds=prompt_embeds,
- prompt_embeds_pooled=prompt_embeds_pooled,
- negative_prompt_embeds=negative_prompt_embeds,
- negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
- num_inference_steps=1,
- prior_num_inference_steps=1,
- output_type="np",
- generator=generator.manual_seed(0),
- )
-
- assert np.abs(output_prompt.images - output_prompt_embeds.images).max() < 1e-5
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
index 07e4244e3c68..afcd8fca71ca 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
@@ -24,11 +24,12 @@
from diffusers.models import StableCascadeUNet
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
load_numpy,
load_pt,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
skip_mps,
slow,
torch_device,
@@ -208,45 +209,6 @@ def test_attention_slicing_forward_pass(self):
def test_float16_inference(self):
super().test_float16_inference()
- def test_stable_cascade_decoder_prompt_embeds(self):
- device = "cpu"
- components = self.get_dummy_components()
-
- pipe = StableCascadeDecoderPipeline(**components)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image_embeddings = inputs["image_embeddings"]
- prompt = "A photograph of a shiba inu, wearing a hat"
- (
- prompt_embeds,
- prompt_embeds_pooled,
- negative_prompt_embeds,
- negative_prompt_embeds_pooled,
- ) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
- generator = torch.Generator(device=device)
-
- decoder_output_prompt = pipe(
- image_embeddings=image_embeddings,
- prompt=prompt,
- num_inference_steps=1,
- output_type="np",
- generator=generator.manual_seed(0),
- )
- decoder_output_prompt_embeds = pipe(
- image_embeddings=image_embeddings,
- prompt=None,
- prompt_embeds=prompt_embeds,
- prompt_embeds_pooled=prompt_embeds_pooled,
- negative_prompt_embeds=negative_prompt_embeds,
- negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
- num_inference_steps=1,
- output_type="np",
- generator=generator.manual_seed(0),
- )
-
- assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
-
def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings(self):
device = "cpu"
components = self.get_dummy_components()
@@ -307,27 +269,35 @@ def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings_with_gui
batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "batch_size": 1,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_stable_cascade_decoder(self):
pipe = StableCascadeDecoderPipeline.from_pretrained(
"stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
index 0208224a1d80..0374de9b0219 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
@@ -24,11 +24,12 @@
from diffusers.models import StableCascadeUNet
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
load_numpy,
numpy_cosine_similarity_distance,
require_peft_backend,
- require_torch_gpu,
+ require_torch_accelerator,
skip_mps,
slow,
torch_device,
@@ -240,62 +241,31 @@ def test_inference_with_prior_lora(self):
self.assertTrue(image_embed.shape == lora_image_embed.shape)
- def test_stable_cascade_decoder_prompt_embeds(self):
- device = "cpu"
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "A photograph of a shiba inu, wearing a hat"
- (
- prompt_embeds,
- prompt_embeds_pooled,
- negative_prompt_embeds,
- negative_prompt_embeds_pooled,
- ) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
- generator = torch.Generator(device=device)
-
- output_prompt = pipe(
- prompt=prompt,
- num_inference_steps=1,
- output_type="np",
- generator=generator.manual_seed(0),
- )
- output_prompt_embeds = pipe(
- prompt=None,
- prompt_embeds=prompt_embeds,
- prompt_embeds_pooled=prompt_embeds_pooled,
- negative_prompt_embeds=negative_prompt_embeds,
- negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
- num_inference_steps=1,
- output_type="np",
- generator=generator.manual_seed(0),
- )
-
- assert np.abs(output_prompt.image_embeddings - output_prompt_embeds.image_embeddings).max() < 1e-5
+ @unittest.skip("Test not supported because dtype determination relies on text encoder.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableCascadePriorPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_stable_cascade_prior(self):
pipe = StableCascadePriorPipeline.from_pretrained(
"stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index f37d598c8387..6e17b86639ea 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -44,6 +44,10 @@
)
from diffusers.utils.testing_utils import (
CaptureLogger,
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
is_torch_compile,
load_image,
@@ -52,8 +56,8 @@
numpy_cosine_similarity_distance,
require_accelerate_version_greater,
require_torch_2,
- require_torch_gpu,
- require_torch_multi_gpu,
+ require_torch_accelerator,
+ require_torch_multi_accelerator,
run_test_in_subprocess,
skip_mps,
slow,
@@ -123,6 +127,8 @@ class StableDiffusionPipelineFastTests(
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self, time_cond_proj_dim=None):
cross_attention_dim = 8
@@ -373,84 +379,6 @@ def test_stable_diffusion_negative_prompt_embeds(self):
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
- def test_stable_diffusion_prompt_embeds_no_text_encoder_or_tokenizer(self):
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPipeline(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt"] = "this is a negative prompt"
-
- # forward
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
- negative_prompt = "this is a negative prompt"
-
- prompt_embeds, negative_prompt_embeds = sd_pipe.encode_prompt(
- prompt,
- torch_device,
- 1,
- True,
- negative_prompt=negative_prompt,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- )
-
- inputs["prompt_embeds"] = prompt_embeds
- inputs["negative_prompt_embeds"] = negative_prompt_embeds
-
- sd_pipe.text_encoder = None
- sd_pipe.tokenizer = None
-
- # forward
- output = sd_pipe(**inputs)
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
- def test_stable_diffusion_prompt_embeds_with_plain_negative_prompt_list(self):
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPipeline(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- inputs["negative_prompt"] = negative_prompt
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt"] = negative_prompt
- prompt = 3 * [inputs.pop("prompt")]
-
- text_inputs = sd_pipe.tokenizer(
- prompt,
- padding="max_length",
- max_length=sd_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- prompt_embeds = sd_pipe.text_encoder(text_inputs)[0]
-
- inputs["prompt_embeds"] = prompt_embeds
-
- # forward
- output = sd_pipe(**inputs)
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
def test_stable_diffusion_ddim_factor_8(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -840,13 +768,28 @@ def callback_on_step_end(pipe, i, t, callback_kwargs):
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
+ def test_pipeline_accept_tuple_type_unet_sample_size(self):
+ # the purpose of this test is to see whether the pipeline would accept a unet with the tuple-typed sample size
+ sd_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+ sample_size = [60, 80]
+ customised_unet = UNet2DConditionModel(sample_size=sample_size)
+ pipe = StableDiffusionPipeline.from_pretrained(sd_repo_id, unet=customised_unet)
+ assert pipe.unet.config.sample_size == sample_size
+
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionPipelineSlowTests(unittest.TestCase):
def setUp(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -948,7 +891,7 @@ def test_stable_diffusion_dpm(self):
assert np.abs(image_slice - expected_slice).max() < 3e-3
def test_stable_diffusion_attention_slicing(self):
- torch.cuda.reset_peak_memory_stats()
+ backend_reset_peak_memory_stats(torch_device)
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe.unet.set_default_attn_processor()
pipe = pipe.to(torch_device)
@@ -959,8 +902,8 @@ def test_stable_diffusion_attention_slicing(self):
inputs = self.get_inputs(torch_device, dtype=torch.float16)
image_sliced = pipe(**inputs).images
- mem_bytes = torch.cuda.max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ mem_bytes = backend_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
# make sure that less than 3.75 GB is allocated
assert mem_bytes < 3.75 * 10**9
@@ -971,13 +914,13 @@ def test_stable_diffusion_attention_slicing(self):
image = pipe(**inputs).images
# make sure that more than 3.75 GB is allocated
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes > 3.75 * 10**9
max_diff = numpy_cosine_similarity_distance(image_sliced.flatten(), image.flatten())
assert max_diff < 1e-3
def test_stable_diffusion_vae_slicing(self):
- torch.cuda.reset_peak_memory_stats()
+ backend_reset_peak_memory_stats(torch_device)
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -990,8 +933,8 @@ def test_stable_diffusion_vae_slicing(self):
inputs["latents"] = torch.cat([inputs["latents"]] * 4)
image_sliced = pipe(**inputs).images
- mem_bytes = torch.cuda.max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ mem_bytes = backend_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
# make sure that less than 4 GB is allocated
assert mem_bytes < 4e9
@@ -1003,14 +946,14 @@ def test_stable_diffusion_vae_slicing(self):
image = pipe(**inputs).images
# make sure that more than 4 GB is allocated
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes > 4e9
# There is a small discrepancy at the image borders vs. a fully batched version.
max_diff = numpy_cosine_similarity_distance(image_sliced.flatten(), image.flatten())
assert max_diff < 1e-2
def test_stable_diffusion_vae_tiling(self):
- torch.cuda.reset_peak_memory_stats()
+ backend_reset_peak_memory_stats(torch_device)
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(
model_id, variant="fp16", torch_dtype=torch.float16, safety_checker=None
@@ -1024,7 +967,7 @@ def test_stable_diffusion_vae_tiling(self):
# enable vae tiling
pipe.enable_vae_tiling()
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
output_chunked = pipe(
[prompt],
@@ -1037,7 +980,7 @@ def test_stable_diffusion_vae_tiling(self):
)
image_chunked = output_chunked.images
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# disable vae tiling
pipe.disable_vae_tiling()
@@ -1130,26 +1073,25 @@ def test_stable_diffusion_low_cpu_mem_usage(self):
assert 2 * low_cpu_mem_usage_time < normal_load_time
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
_ = pipe(**inputs)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 2.8 GB is allocated
assert mem_bytes < 2.8 * 10**9
def test_stable_diffusion_pipeline_with_model_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
@@ -1163,7 +1105,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
outputs = pipe(**inputs)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# With model offloading
@@ -1174,16 +1116,16 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
)
pipe.unet.set_default_attn_processor()
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
outputs_offloaded = pipe(**inputs)
- mem_bytes_offloaded = torch.cuda.max_memory_allocated()
+ mem_bytes_offloaded = backend_max_memory_allocated(torch_device)
images = outputs.images
offloaded_images = outputs_offloaded.images
@@ -1196,13 +1138,13 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
assert module.device == torch.device("cpu")
# With attention slicing
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
pipe.enable_attention_slicing()
_ = pipe(**inputs)
- mem_bytes_slicing = torch.cuda.max_memory_allocated()
+ mem_bytes_slicing = backend_max_memory_allocated(torch_device)
assert mem_bytes_slicing < mem_bytes_offloaded
assert mem_bytes_slicing < 3 * 10**9
@@ -1217,7 +1159,7 @@ def test_stable_diffusion_textual_inversion(self):
)
pipe.load_textual_inversion(a111_file)
pipe.load_textual_inversion(a111_file_neg)
- pipe.to("cuda")
+ pipe.to(torch_device)
generator = torch.Generator(device="cpu").manual_seed(1)
@@ -1234,7 +1176,7 @@ def test_stable_diffusion_textual_inversion(self):
def test_stable_diffusion_textual_inversion_with_model_cpu_offload(self):
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")
a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt")
@@ -1259,8 +1201,8 @@ def test_stable_diffusion_textual_inversion_with_model_cpu_offload(self):
def test_stable_diffusion_textual_inversion_with_sequential_cpu_offload(self):
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
- pipe.enable_sequential_cpu_offload()
- pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")
+ pipe.enable_sequential_cpu_offload(device=torch_device)
+ pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons").to(torch_device)
a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt")
a111_file_neg = hf_hub_download(
@@ -1318,17 +1260,17 @@ def test_stable_diffusion_lcm(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionPipelineCkptTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_download_from_hub(self):
ckpt_paths = [
@@ -1339,7 +1281,7 @@ def test_download_from_hub(self):
for ckpt_path in ckpt_paths:
pipe = StableDiffusionPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- pipe.to("cuda")
+ pipe.to(torch_device)
image_out = pipe("test", num_inference_steps=1, output_type="np").images[0]
@@ -1355,7 +1297,7 @@ def test_download_local(self):
ckpt_filename, config_files={"v1": config_filename}, torch_dtype=torch.float16
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- pipe.to("cuda")
+ pipe.to(torch_device)
image_out = pipe("test", num_inference_steps=1, output_type="np").images[0]
@@ -1363,17 +1305,17 @@ def test_download_local(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionPipelineNightlyTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -1467,13 +1409,13 @@ def test_stable_diffusion_euler(self):
# (sayakpaul): This test suite was run in the DGX with two GPUs (1, 2).
@slow
-@require_torch_multi_gpu
+@require_torch_multi_accelerator
@require_accelerate_version_greater("0.27.0")
class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, generator_device="cpu", seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -1555,7 +1497,7 @@ def test_reset_device_map_to(self):
assert sd_pipe_with_device_map.hf_device_map is None
# Make sure `to()` can be used and the pipeline can be called.
- pipe = sd_pipe_with_device_map.to("cuda")
+ pipe = sd_pipe_with_device_map.to(torch_device)
_ = pipe("hello", num_inference_steps=2)
def test_reset_device_map_enable_model_cpu_offload(self):
@@ -1567,7 +1509,7 @@ def test_reset_device_map_enable_model_cpu_offload(self):
assert sd_pipe_with_device_map.hf_device_map is None
# Make sure `enable_model_cpu_offload()` can be used and the pipeline can be called.
- sd_pipe_with_device_map.enable_model_cpu_offload()
+ sd_pipe_with_device_map.enable_model_cpu_offload(device=torch_device)
_ = sd_pipe_with_device_map("hello", num_inference_steps=2)
def test_reset_device_map_enable_sequential_cpu_offload(self):
@@ -1579,5 +1521,5 @@ def test_reset_device_map_enable_sequential_cpu_offload(self):
assert sd_pipe_with_device_map.hf_device_map is None
# Make sure `enable_sequential_cpu_offload()` can be used and the pipeline can be called.
- sd_pipe_with_device_map.enable_sequential_cpu_offload()
+ sd_pipe_with_device_map.enable_sequential_cpu_offload(device=torch_device)
_ = sd_pipe_with_device_map("hello", num_inference_steps=2)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index 7ba0bb5a4a5d..82b01a74869a 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -35,6 +35,10 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
floats_tensor,
is_torch_compile,
@@ -42,7 +46,7 @@
load_numpy,
nightly,
require_torch_2,
- require_torch_gpu,
+ require_torch_accelerator,
run_test_in_subprocess,
skip_mps,
slow,
@@ -391,19 +395,26 @@ def callback_on_step_end(pipe, i, t, callback_kwargs):
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -506,28 +517,28 @@ def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
assert number_of_steps == 2
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
_ = pipe(**inputs)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 2.2 GB is allocated
assert mem_bytes < 2.2 * 10**9
def test_stable_diffusion_pipeline_with_model_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
@@ -541,7 +552,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe(**inputs)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# With model offloading
@@ -552,14 +563,14 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
torch_dtype=torch.float16,
)
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
_ = pipe(**inputs)
- mem_bytes_offloaded = torch.cuda.max_memory_allocated()
+ mem_bytes_offloaded = backend_max_memory_allocated(torch_device)
assert mem_bytes_offloaded < mem_bytes
for module in pipe.text_encoder, pipe.unet, pipe.vae:
@@ -656,17 +667,17 @@ def test_img2img_compile(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index ff04ea2cfc5d..e21cf23b8cbf 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -37,6 +37,10 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
floats_tensor,
is_torch_compile,
@@ -44,7 +48,7 @@
load_numpy,
nightly,
require_torch_2,
- require_torch_gpu,
+ require_torch_accelerator,
run_test_in_subprocess,
slow,
torch_device,
@@ -394,6 +398,13 @@ def test_ip_adapter(self, from_simple=False, expected_pipe_slice=None):
)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol=1e-3, rtol=1e-3)
+
class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
pipeline_class = StableDiffusionInpaintPipeline
@@ -595,7 +606,7 @@ def test_stable_diffusion_inpaint_euler(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
@@ -603,7 +614,7 @@ def setUp(self):
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -697,21 +708,21 @@ def test_stable_diffusion_inpaint_k_lms(self):
assert np.abs(expected_slice - image_slice).max() < 6e-3
def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"botp/stable-diffusion-v1-5-inpainting", safety_checker=None, torch_dtype=torch.float16
)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
_ = pipe(**inputs)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 2.2 GB is allocated
assert mem_bytes < 2.2 * 10**9
@@ -786,7 +797,7 @@ def test_stable_diffusion_simple_inpaint_ddim(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionInpaintPipelineAsymmetricAutoencoderKLSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
@@ -794,7 +805,7 @@ def setUp(self):
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -900,9 +911,9 @@ def test_stable_diffusion_inpaint_k_lms(self):
assert np.abs(expected_slice - image_slice).max() < 6e-3
def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
vae = AsymmetricAutoencoderKL.from_pretrained(
"cross-attention/asymmetric-autoencoder-kl-x-1-5", torch_dtype=torch.float16
@@ -913,12 +924,12 @@ def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
pipe.vae = vae
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
_ = pipe(**inputs)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 2.45 GB is allocated
assert mem_bytes < 2.45 * 10**9
@@ -1002,7 +1013,7 @@ def test_download_local(self):
pipe = StableDiffusionInpaintPipeline.from_single_file(filename, torch_dtype=torch.float16)
pipe.vae = vae
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- pipe.to("cuda")
+ pipe.to(torch_device)
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 1
@@ -1012,17 +1023,17 @@ def test_download_local(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
index b9b061c060c0..9721bb02ee3e 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
@@ -33,10 +33,14 @@
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
floats_tensor,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -206,9 +210,6 @@ def test_stable_diffusion_pix2pix_euler(self):
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
- slice = [round(x, 4) for x in image_slice.flatten().tolist()]
- print(",".join([str(x) for x in slice]))
-
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.7417, 0.3842, 0.4732, 0.5776, 0.5891, 0.5139, 0.4052, 0.5673, 0.4986])
@@ -269,17 +270,17 @@ def callback_no_cfg(pipe, i, t, callback_kwargs):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionInstructPix2PixPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, seed=0):
generator = torch.manual_seed(seed)
@@ -387,21 +388,21 @@ def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
assert number_of_steps == 3
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
"timbrooks/instruct-pix2pix", safety_checker=None, torch_dtype=torch.float16
)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
inputs = self.get_inputs()
_ = pipe(**inputs)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 2.2 GB is allocated
assert mem_bytes < 2.2 * 10**9
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
index e7114d19e208..3f9f7e965b40 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
@@ -34,12 +34,13 @@
from diffusers.utils.testing_utils import (
CaptureLogger,
backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
require_torch_accelerator,
- require_torch_gpu,
skip_mps,
slow,
torch_device,
@@ -75,6 +76,8 @@ class StableDiffusion2PipelineFastTests(
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
@@ -310,6 +313,13 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
@require_torch_accelerator
@@ -321,9 +331,8 @@ def tearDown(self):
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- _generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
if not str(device).startswith("mps"):
- generator = torch.Generator(device=_generator_device).manual_seed(seed)
+ generator = torch.Generator(device=generator_device).manual_seed(seed)
else:
generator = torch.manual_seed(seed)
@@ -352,9 +361,9 @@ def test_stable_diffusion_default_ddim(self):
expected_slice = np.array([0.49493, 0.47896, 0.40798, 0.54214, 0.53212, 0.48202, 0.47656, 0.46329, 0.48506])
assert np.abs(image_slice - expected_slice).max() < 7e-3
- @require_torch_gpu
+ @require_torch_accelerator
def test_stable_diffusion_attention_slicing(self):
- torch.cuda.reset_peak_memory_stats()
+ backend_reset_peak_memory_stats(torch_device)
pipe = StableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-base", torch_dtype=torch.float16
)
@@ -367,8 +376,8 @@ def test_stable_diffusion_attention_slicing(self):
inputs = self.get_inputs(torch_device, dtype=torch.float16)
image_sliced = pipe(**inputs).images
- mem_bytes = torch.cuda.max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ mem_bytes = backend_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
# make sure that less than 3.3 GB is allocated
assert mem_bytes < 3.3 * 10**9
@@ -379,7 +388,7 @@ def test_stable_diffusion_attention_slicing(self):
image = pipe(**inputs).images
# make sure that more than 3.3 GB is allocated
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes > 3.3 * 10**9
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_sliced.flatten())
assert max_diff < 5e-3
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
index 4c2b3a3c1e85..c66491b15c66 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
@@ -30,7 +30,7 @@
load_numpy,
nightly,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
skip_mps,
torch_device,
)
@@ -204,8 +204,15 @@ def test_karras_schedulers_shape(self):
def test_from_pipe_consistent_forward_pass_cpu_offload(self):
super().test_from_pipe_consistent_forward_pass_cpu_offload(expected_max_diff=5e-3)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
-@require_torch_gpu
+@require_torch_accelerator
@nightly
class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase):
# Attend and excite requires being able to run a backward pass at
@@ -237,7 +244,7 @@ def test_attend_and_excite_fp16(self):
pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
- pipe.to("cuda")
+ pipe.to(torch_device)
prompt = "a painting of an elephant with glasses"
token_indices = [5, 7]
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
index 42eef061069e..0a0051816162 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
@@ -36,14 +36,16 @@
StableDiffusionDepth2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils import is_accelerate_available, is_accelerate_version
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
nightly,
- require_torch_gpu,
+ require_accelerate_version_greater,
+ require_accelerator,
+ require_torch_accelerator,
skip_mps,
slow,
torch_device,
@@ -75,6 +77,8 @@ class StableDiffusionDepth2ImgPipelineFastTests(
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"depth_mask"})
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -194,7 +198,8 @@ def test_save_load_local(self):
max_diff = np.abs(output - output_loaded).max()
self.assertLess(max_diff, 1e-4)
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_save_load_float16(self):
components = self.get_dummy_components()
for name, module in components.items():
@@ -226,7 +231,8 @@ def test_save_load_float16(self):
max_diff = np.abs(output - output_loaded).max()
self.assertLess(max_diff, 2e-2, "The output of the fp16 pipeline changed after saving and loading.")
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_float16_inference(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -246,10 +252,8 @@ def test_float16_inference(self):
max_diff = np.abs(output - output_fp16).max()
self.assertLess(max_diff, 1.3e-2, "The outputs of the fp16 and fp32 pipelines are too different.")
- @unittest.skipIf(
- torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
- reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
- )
+ @require_accelerator
+ @require_accelerate_version_greater("0.14.0")
def test_cpu_offload_forward_pass(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -259,7 +263,7 @@ def test_cpu_offload_forward_pass(self):
inputs = self.get_dummy_inputs(torch_device)
output_without_offload = pipe(**inputs)[0]
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_offload = pipe(**inputs)[0]
@@ -366,19 +370,26 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=7e-3)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionDepth2ImgPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
@@ -415,17 +426,17 @@ def test_stable_diffusion_depth2img_pipeline_default(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionImg2ImgPipelineNightlyTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=device).manual_seed(seed)
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
index 1cb03ddd96d7..34ea56664a95 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
@@ -33,12 +33,13 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
nightly,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
torch_device,
)
@@ -291,19 +292,26 @@ def test_inversion_dpm(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
-@require_torch_gpu
+@require_torch_accelerator
@nightly
class StableDiffusionDiffEditPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
@classmethod
def setUpClass(cls):
@@ -324,7 +332,7 @@ def test_stable_diffusion_diffedit_full(self):
pipe.scheduler.clip_sample = True
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
source_prompt = "a bowl of fruit"
@@ -370,17 +378,17 @@ def test_stable_diffusion_diffedit_full(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionDiffEditPipelineNightlyTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
@classmethod
def setUpClass(cls):
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
index dc855f44b817..9e4fa767085f 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
@@ -62,7 +62,7 @@ def test_stable_diffusion_flax(self):
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512])
- print(f"output_slice: {output_slice}")
+
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
@@ -104,5 +104,5 @@ def test_stable_diffusion_dpm_flax(self):
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297])
- print(f"output_slice: {output_slice}")
+
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py
index 8f039980ec24..eeec52dab51d 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py
@@ -78,5 +78,5 @@ def test_stable_diffusion_inpaint_pipeline(self):
expected_slice = jnp.array(
[0.3611307, 0.37649736, 0.3757408, 0.38213953, 0.39295167, 0.3841631, 0.41554978, 0.4137475, 0.4217084]
)
- print(f"output_slice: {output_slice}")
+
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
index b99a1816456e..2feeaaf11c12 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
@@ -24,11 +24,14 @@
from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -152,21 +155,28 @@ def test_stable_diffusion_inpaint(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_stable_diffusion_inpaint_pipeline(self):
init_image = load_image(
@@ -241,9 +251,9 @@ def test_stable_diffusion_inpaint_pipeline_fp16(self):
assert np.abs(expected_image - image).max() < 5e-1
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
@@ -263,7 +273,7 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
index 134175bdaffe..22e588a9327b 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
@@ -31,11 +31,12 @@
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -279,30 +280,34 @@ def test_karras_schedulers_shape(self):
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=5e-1)
+ @unittest.skip("Test not supported for a weird use of `text_input_ids`.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
-@require_torch_gpu
+
+@require_torch_accelerator
@slow
class StableDiffusionLatentUpscalePipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_latent_upscaler_fp16(self):
generator = torch.manual_seed(33)
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
- pipe.to("cuda")
+ pipe.to(torch_device)
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
)
- upscaler.to("cuda")
+ upscaler.to(torch_device)
prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic"
@@ -328,7 +333,7 @@ def test_latent_upscaler_fp16_image(self):
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
)
- upscaler.to("cuda")
+ upscaler.to(torch_device)
prompt = "the temple of fire by Ross Tran and Gerardo Dottori, oil on canvas"
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
index c21da7af6d2c..5400c21c9f87 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
@@ -25,11 +25,16 @@
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionUpscalePipeline, UNet2DConditionModel
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
- require_torch_gpu,
+ require_accelerator,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -43,13 +48,13 @@ def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
@property
def dummy_image(self):
@@ -289,7 +294,7 @@ def test_stable_diffusion_upscale_prompt_embeds(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_prompt_embeds_slice.flatten() - expected_slice).max() < 1e-2
- @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
+ @require_accelerator
def test_stable_diffusion_upscale_fp16(self):
"""Test that stable diffusion upscale works with fp16"""
unet = self.dummy_cond_unet_upscale
@@ -380,19 +385,19 @@ def test_stable_diffusion_upscale_from_save_pretrained(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_stable_diffusion_upscale_pipeline(self):
image = load_image(
@@ -458,9 +463,9 @@ def test_stable_diffusion_upscale_pipeline_fp16(self):
assert np.abs(expected_image - image).max() < 5e-1
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
@@ -474,7 +479,7 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
prompt = "a cat sitting on a park bench"
@@ -487,6 +492,6 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
output_type="np",
)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 2.9 GB is allocated
assert mem_bytes < 2.9 * 10**9
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
index 703c3b7a39d8..1953017c0ee8 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
@@ -31,10 +31,15 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
load_numpy,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_accelerator,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -48,13 +53,13 @@ def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
@property
def dummy_cond_unet(self):
@@ -213,7 +218,7 @@ def test_stable_diffusion_v_pred_k_euler(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
+ @require_accelerator
def test_stable_diffusion_v_pred_fp16(self):
"""Test that stable diffusion v-prediction works with fp16"""
unet = self.dummy_cond_unet
@@ -257,19 +262,19 @@ def test_stable_diffusion_v_pred_fp16(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_stable_diffusion_v_pred_default(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
@@ -356,7 +361,7 @@ def test_stable_diffusion_v_pred_dpm(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_attention_slicing_v_pred(self):
- torch.cuda.reset_peak_memory_stats()
+ backend_reset_peak_memory_stats(torch_device)
model_id = "stabilityai/stable-diffusion-2"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.to(torch_device)
@@ -372,8 +377,8 @@ def test_stable_diffusion_attention_slicing_v_pred(self):
)
image_chunked = output_chunked.images
- mem_bytes = torch.cuda.max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ mem_bytes = backend_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
# make sure that less than 5.5 GB is allocated
assert mem_bytes < 5.5 * 10**9
@@ -384,7 +389,7 @@ def test_stable_diffusion_attention_slicing_v_pred(self):
image = output.images
# make sure that more than 3.0 GB is allocated
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes > 3 * 10**9
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_chunked.flatten())
assert max_diff < 1e-3
@@ -420,7 +425,7 @@ def test_stable_diffusion_text2img_pipeline_unflawed(self):
pipe.scheduler = DDIMScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing", rescale_betas_zero_snr=True
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
@@ -465,7 +470,7 @@ def test_download_local(self):
pipe = StableDiffusionPipeline.from_single_file(filename, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
image_out = pipe("test", num_inference_steps=1, output_type="np").images[0]
@@ -529,20 +534,20 @@ def test_stable_diffusion_low_cpu_mem_usage_v_pred(self):
assert 2 * low_cpu_mem_usage_time < normal_load_time
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading_v_pred(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
pipeline_id = "stabilityai/stable-diffusion-2"
prompt = "Andromeda galaxy in a bottle"
pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, torch_dtype=torch.float16)
pipeline.enable_attention_slicing(1)
- pipeline.enable_sequential_cpu_offload()
+ pipeline.enable_sequential_cpu_offload(device=torch_device)
generator = torch.manual_seed(0)
_ = pipeline(prompt, generator=generator, num_inference_steps=5)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 2.8 GB is allocated
assert mem_bytes < 2.8 * 10**9
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index 94a85a56f510..38ef6143f4c0 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -2,13 +2,15 @@
import unittest
import numpy as np
+import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_big_accelerator,
slow,
torch_device,
)
@@ -34,6 +36,8 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
@@ -102,6 +106,8 @@ def get_dummy_components(self):
"tokenizer_3": tokenizer_3,
"transformer": transformer,
"vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
@@ -151,39 +157,6 @@ def test_stable_diffusion_3_different_negative_prompts(self):
# Outputs should be different here
assert max_diff > 1e-2
- def test_stable_diffusion_3_prompt_embeds(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
- inputs = self.get_dummy_inputs(torch_device)
-
- output_with_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- do_classifier_free_guidance = inputs["guidance_scale"] > 1
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = pipe.encode_prompt(
- prompt,
- prompt_2=None,
- prompt_3=None,
- do_classifier_free_guidance=do_classifier_free_guidance,
- device=torch_device,
- )
- output_with_embeds = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- **inputs,
- ).images[0]
-
- max_diff = np.abs(output_with_prompt - output_with_embeds).max()
- assert max_diff < 1e-4
-
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -224,9 +197,43 @@ def test_fused_qkv_projections(self):
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
+ def test_skip_guidance_layers(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ output_full = pipe(**inputs)[0]
+
+ inputs_with_skip = inputs.copy()
+ inputs_with_skip["skip_guidance_layers"] = [0]
+ output_skip = pipe(**inputs_with_skip)[0]
+
+ self.assertFalse(
+ np.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
+ )
+
+ self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
+
+ inputs["num_images_per_prompt"] = 2
+ output_full = pipe(**inputs)[0]
+
+ inputs_with_skip = inputs.copy()
+ inputs_with_skip["skip_guidance_layers"] = [0]
+ output_skip = pipe(**inputs_with_skip)[0]
+
+ self.assertFalse(
+ np.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
+ )
+
+ self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
+
@slow
-@require_torch_gpu
+@require_big_accelerator
+@pytest.mark.big_gpu_with_torch_cuda
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Pipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
@@ -234,12 +241,12 @@ class StableDiffusion3PipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
@@ -257,7 +264,7 @@ def get_inputs(self, device, seed=0):
def test_sd3_inference(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device)
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
index 9d131b28c308..f7c450aab93e 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
@@ -3,6 +3,7 @@
import unittest
import numpy as np
+import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -14,9 +15,10 @@
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
floats_tensor,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_big_accelerator,
slow,
torch_device,
)
@@ -104,6 +106,8 @@ def get_dummy_components(self):
"tokenizer_3": tokenizer_3,
"transformer": transformer,
"vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
@@ -156,45 +160,14 @@ def test_stable_diffusion_3_img2img_different_negative_prompts(self):
# Outputs should be different here
assert max_diff > 1e-2
- def test_stable_diffusion_3_img2img_prompt_embeds(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
- inputs = self.get_dummy_inputs(torch_device)
-
- output_with_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- do_classifier_free_guidance = inputs["guidance_scale"] > 1
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = pipe.encode_prompt(
- prompt,
- prompt_2=None,
- prompt_3=None,
- do_classifier_free_guidance=do_classifier_free_guidance,
- device=torch_device,
- )
- output_with_embeds = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- **inputs,
- ).images[0]
-
- max_diff = np.abs(output_with_prompt - output_with_embeds).max()
- assert max_diff < 1e-4
-
+ @unittest.skip("Skip for now.")
def test_multi_vae(self):
pass
@slow
-@require_torch_gpu
+@require_big_accelerator
+@pytest.mark.big_gpu_with_torch_cuda
class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
@@ -202,12 +175,12 @@ class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, seed=0):
init_image = load_image(
@@ -229,11 +202,10 @@ def get_inputs(self, device, seed=0):
}
def test_sd3_img2img_inference(self):
+ torch.manual_seed(0)
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
-
+ pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device)
-
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py
index 464ef6d017df..4090306dec72 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py
@@ -106,6 +106,8 @@ def get_dummy_components(self):
"tokenizer_3": tokenizer_3,
"transformer": transformer,
"vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
@@ -162,38 +164,5 @@ def test_stable_diffusion_3_inpaint_different_negative_prompts(self):
# Outputs should be different here
assert max_diff > 1e-2
- def test_stable_diffusion_3_inpaint_prompt_embeds(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
- inputs = self.get_dummy_inputs(torch_device)
-
- output_with_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = inputs.pop("prompt")
-
- do_classifier_free_guidance = inputs["guidance_scale"] > 1
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = pipe.encode_prompt(
- prompt,
- prompt_2=None,
- prompt_3=None,
- do_classifier_free_guidance=do_classifier_free_guidance,
- device=torch_device,
- )
- output_with_embeds = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- **inputs,
- ).images[0]
-
- max_diff = np.abs(output_with_prompt - output_with_embeds).max()
- assert max_diff < 1e-4
-
def test_multi_vae(self):
pass
diff --git a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
index 2a1e691e9e8f..009c75df4249 100644
--- a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
+++ b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
@@ -35,12 +35,13 @@
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -336,6 +337,13 @@ def test_adapter_lcm_custom_timesteps(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
class StableDiffusionFullAdapterPipelineFastTests(
AdapterTests, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
@@ -389,6 +397,8 @@ def test_stable_diffusion_adapter_default_case(self):
class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
+ supports_dduf = False
+
def get_dummy_components(self, time_cond_proj_dim=None):
return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)
@@ -595,17 +605,17 @@ def test_inference_batch_single_identical(
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionAdapterPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_stable_diffusion_adapter_depth_sd_v15(self):
adapter_model = "TencentARC/t2iadapter_depth_sd15v2"
diff --git a/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py b/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
index 405809aee19e..b3ac507f768e 100644
--- a/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
+++ b/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
@@ -169,3 +169,7 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)
+
+ @unittest.skip("Test not supported as tokenizer is used for parsing bounding boxes.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py b/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
index 748702541b1e..b080bb987e13 100644
--- a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
+++ b/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
@@ -66,6 +66,8 @@ class GligenTextImagePipelineFastTests(
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -205,3 +207,9 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)
+
+ @unittest.skip(
+ "Test not supported because of the use of `text_encoder` in `get_cross_attention_kwargs_with_grounded()`."
+ )
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py
index 7a3b0f70ccb1..f706e7000b28 100644
--- a/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py
+++ b/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py
@@ -30,13 +30,17 @@
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -58,6 +62,8 @@ class StableDiffusionImageVariationPipelineFastTests(
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -162,17 +168,17 @@ def test_inference_batch_single_identical(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -256,37 +262,37 @@ def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
assert number_of_steps == inputs["num_inference_steps"]
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
"lambdalabs/sd-image-variations-diffusers", safety_checker=None, torch_dtype=torch.float16
)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device, dtype=torch.float16)
_ = pipe(**inputs)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 2.6 GB is allocated
assert mem_bytes < 2.6 * 10**9
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionImageVariationPipelineNightlyTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
diff --git a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
index 6dc6c31ae9a7..4734af259921 100644
--- a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
+++ b/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
@@ -258,6 +258,13 @@ def test_stable_diffusion_panorama_pndm(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@nightly
@require_torch_gpu
diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py
index ccb20a1c218e..269677c08345 100644
--- a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py
+++ b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py
@@ -24,7 +24,7 @@
from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
-from diffusers.utils.testing_utils import floats_tensor, nightly, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import floats_tensor, nightly, require_accelerator, require_torch_gpu, torch_device
class SafeDiffusionPipelineFastTests(unittest.TestCase):
@@ -228,7 +228,7 @@ def test_stable_diffusion_no_safety_checker(self):
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None
- @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
+ @require_accelerator
def test_stable_diffusion_fp16(self):
"""Test that stable diffusion works with fp16"""
unet = self.dummy_cond_unet
diff --git a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py b/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
index 1d4e66bd65f0..bd1ba268d2d9 100644
--- a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
+++ b/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
@@ -153,6 +153,13 @@ def test_pipeline_different_schedulers(self):
# Karras schedulers are not supported
image = pipeline(**inputs).images[0]
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@nightly
@require_torch_gpu
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 8550f258045e..c68cdf67036a 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -38,7 +38,7 @@
enable_full_determinism,
load_image,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -54,7 +54,6 @@
PipelineLatentTesterMixin,
PipelineTesterMixin,
SDFunctionTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -66,7 +65,6 @@ class StableDiffusionXLPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLPipeline
@@ -75,6 +73,8 @@ class StableDiffusionXLPipelineFastTests(
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
+ test_layerwise_casting = True
+ test_group_offloading = True
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
@@ -252,84 +252,6 @@ def test_stable_diffusion_ays(self):
np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
), "use ays sigmas should have different outputs"
- def test_stable_diffusion_xl_prompt_embeds(self):
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionXLPipeline(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- # forward without prompt embeds
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 2 * [inputs["prompt"]]
- inputs["num_images_per_prompt"] = 2
-
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with prompt embeds
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 2 * [inputs.pop("prompt")]
-
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = sd_pipe.encode_prompt(prompt)
-
- output = sd_pipe(
- **inputs,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- )
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # make sure that it's equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
- def test_stable_diffusion_xl_negative_prompt_embeds(self):
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionXLPipeline(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- # forward without prompt embeds
- inputs = self.get_dummy_inputs(torch_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- inputs["negative_prompt"] = negative_prompt
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with prompt embeds
- inputs = self.get_dummy_inputs(torch_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- prompt = 3 * [inputs.pop("prompt")]
-
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = sd_pipe.encode_prompt(prompt, negative_prompt=negative_prompt)
-
- output = sd_pipe(
- **inputs,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- )
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # make sure that it's equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
@@ -343,10 +265,7 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
- @require_torch_gpu
+ @require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -355,12 +274,12 @@ def test_stable_diffusion_xl_offloads(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -375,41 +294,9 @@ def test_stable_diffusion_xl_offloads(self):
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
- def test_stable_diffusion_xl_img2img_prompt_embeds_only(self):
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionXLPipeline(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- # forward without prompt embeds
- generator_device = "cpu"
- inputs = self.get_dummy_inputs(generator_device)
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with prompt embeds
- generator_device = "cpu"
- inputs = self.get_dummy_inputs(generator_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- (
- prompt_embeds,
- _,
- pooled_prompt_embeds,
- _,
- ) = sd_pipe.encode_prompt(prompt)
-
- output = sd_pipe(
- **inputs,
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- )
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # make sure that it's equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+ @unittest.skip("We test this functionality elsewhere already.")
+ def test_save_load_optional_components(self):
+ pass
def test_stable_diffusion_two_xl_mixture_of_denoiser_fast(self):
components = self.get_dummy_components()
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
index 2091af9c0383..07333623867e 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
@@ -42,7 +42,6 @@
from ..test_pipelines_common import (
IPAdapterTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
assert_mean_pixel_difference,
)
@@ -50,9 +49,7 @@
enable_full_determinism()
-class StableDiffusionXLAdapterPipelineFastTests(
- IPAdapterTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
-):
+class StableDiffusionXLAdapterPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionXLAdapterPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
@@ -300,6 +297,10 @@ def test_ip_adapter(self, from_multi=False, expected_pipe_slice=None):
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
+ @unittest.skip("We test this functionality elsewhere already.")
+ def test_save_load_optional_components(self):
+ pass
+
def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -373,9 +374,6 @@ def test_total_downscale_factor(self, adapter_type):
expected_out_image_size,
)
- def test_save_load_optional_components(self):
- return self._test_save_load_optional_components()
-
def test_adapter_sdxl_lcm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -422,6 +420,8 @@ def test_adapter_sdxl_lcm_custom_timesteps(self):
class StableDiffusionXLMultiAdapterPipelineFastTests(
StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
):
+ supports_dduf = False
+
def get_dummy_components(self, time_cond_proj_dim=None):
return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)
@@ -513,6 +513,10 @@ def test_inference_batch_consistent(
logger.setLevel(level=diffusers.logging.WARNING)
+ @unittest.skip("We test this functionality elsewhere already.")
+ def test_save_load_optional_components(self):
+ pass
+
def test_num_images_per_prompt(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -642,9 +646,6 @@ def test_adapter_sdxl_lcm(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
- debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
- print(",".join(debug))
-
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_adapter_sdxl_lcm_custom_timesteps(self):
@@ -667,7 +668,4 @@ def test_adapter_sdxl_lcm_custom_timesteps(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
- debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
- print(",".join(debug))
-
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index db0905a48310..9a141634a364 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -42,7 +42,7 @@
enable_full_determinism,
floats_tensor,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -57,7 +57,6 @@
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -77,6 +76,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(
{"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
)
+ supports_dduf = False
+
def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -264,52 +265,10 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
- # TODO(Patrick, Sayak) - skip for now as this requires more refiner tests
+ @unittest.skip("Skip for now.")
def test_save_load_optional_components(self):
pass
- def test_stable_diffusion_xl_img2img_negative_prompt_embeds(self):
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- # forward without prompt embeds
- generator_device = "cpu"
- inputs = self.get_dummy_inputs(generator_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- inputs["negative_prompt"] = negative_prompt
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with prompt embeds
- generator_device = "cpu"
- inputs = self.get_dummy_inputs(generator_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- prompt = 3 * [inputs.pop("prompt")]
-
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = sd_pipe.encode_prompt(prompt, negative_prompt=negative_prompt)
-
- output = sd_pipe(
- **inputs,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- )
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # make sure that it's equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
@@ -334,7 +293,7 @@ def test_stable_diffusion_xl_img2img_tiny_autoencoder(self):
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
- @require_torch_gpu
+ @require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -343,12 +302,12 @@ def test_stable_diffusion_xl_offloads(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -517,7 +476,7 @@ def callback_on_step_end(pipe, i, t, callback_kwargs):
class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
- PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+ PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLImg2ImgPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
@@ -637,7 +596,7 @@ def test_stable_diffusion_xl_img2img_euler(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- @require_torch_gpu
+ @require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -646,12 +605,12 @@ def test_stable_diffusion_xl_offloads(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -695,92 +654,15 @@ def test_stable_diffusion_xl_img2img_negative_conditions(self):
> 1e-4
)
- def test_stable_diffusion_xl_img2img_negative_prompt_embeds(self):
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- # forward without prompt embeds
- generator_device = "cpu"
- inputs = self.get_dummy_inputs(generator_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- inputs["negative_prompt"] = negative_prompt
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with prompt embeds
- generator_device = "cpu"
- inputs = self.get_dummy_inputs(generator_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- prompt = 3 * [inputs.pop("prompt")]
-
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = sd_pipe.encode_prompt(prompt, negative_prompt=negative_prompt)
-
- output = sd_pipe(
- **inputs,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- )
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # make sure that it's equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
- def test_stable_diffusion_xl_img2img_prompt_embeds_only(self):
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- # forward without prompt embeds
- generator_device = "cpu"
- inputs = self.get_dummy_inputs(generator_device)
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with prompt embeds
- generator_device = "cpu"
- inputs = self.get_dummy_inputs(generator_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- (
- prompt_embeds,
- _,
- pooled_prompt_embeds,
- _,
- ) = sd_pipe.encode_prompt(prompt)
-
- output = sd_pipe(
- **inputs,
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- )
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # make sure that it's equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+ @unittest.skip("We test this functionality elsewhere already.")
def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
+ pass
@slow
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index 964c7123dd32..66ae581a0529 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -41,7 +41,13 @@
UNet2DConditionModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
@@ -72,6 +78,8 @@ class StableDiffusionXLInpaintPipelineFastTests(
}
)
+ supports_dduf = False
+
def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -299,10 +307,11 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
- # TODO(Patrick, Sayak) - skip for now as this requires more refiner tests
+ @unittest.skip("Skip for now.")
def test_save_load_optional_components(self):
pass
+ @require_torch_accelerator
def test_stable_diffusion_xl_inpaint_negative_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLInpaintPipeline(**components)
@@ -343,7 +352,7 @@ def test_stable_diffusion_xl_inpaint_negative_prompt_embeds(self):
# make sure that it's equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
- @require_torch_gpu
+ @require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -352,12 +361,12 @@ def test_stable_diffusion_xl_offloads(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLInpaintPipeline(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLInpaintPipeline(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
index 98cecb4e0f7c..79d38c4a7b43 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
@@ -40,7 +40,6 @@
PipelineKarrasSchedulerTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -51,7 +50,6 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
PipelineLatentTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLInstructPix2PixPipeline
@@ -182,8 +180,10 @@ def test_latents_input(self):
max_diff = np.abs(out - out_latents_inputs).max()
self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
+ @unittest.skip("Test not supported at the moment.")
def test_cfg(self):
pass
+ @unittest.skip("Functionality is tested elsewhere.")
def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
+ pass
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py
index 94ee9f0facc8..46f7d0e7b0b4 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py
@@ -20,14 +20,20 @@
import torch
from diffusers import StableDiffusionXLKDiffusionPipeline
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
enable_full_determinism()
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionXLKPipelineIntegrationTests(unittest.TestCase):
dtype = torch.float16
@@ -35,13 +41,13 @@ def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_stable_diffusion_xl(self):
sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py
index bb54d212a786..8cf103dffd56 100644
--- a/tests/pipelines/stable_unclip/test_stable_unclip.py
+++ b/tests/pipelines/stable_unclip/test_stable_unclip.py
@@ -184,6 +184,10 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
+ @unittest.skip("Test not supported because of the use of `_encode_prior_prompt()`.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
@nightly
@require_torch_gpu
diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py
index a5cbf7761501..176b6954d616 100644
--- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py
+++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py
@@ -51,6 +51,8 @@ class StableUnCLIPImg2ImgPipelineFastTests(
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
+ supports_dduf = False
+
def get_dummy_components(self):
embedder_hidden_size = 32
embedder_projection_dim = embedder_hidden_size
@@ -205,6 +207,10 @@ def test_inference_batch_single_identical(self):
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(test_max_difference=False)
+ @unittest.skip("Test not supported at the moment.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
@nightly
@require_torch_gpu
diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
index 60fc21e2027b..f77a5b1620d2 100644
--- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
+++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
@@ -18,14 +18,17 @@
StableVideoDiffusionPipeline,
UNetSpatioTemporalConditionModel,
)
-from diffusers.utils import is_accelerate_available, is_accelerate_version, load_image, logging
+from diffusers.utils import load_image, logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
CaptureLogger,
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_accelerate_version_greater,
+ require_accelerator,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -56,6 +59,8 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa
]
)
+ supports_dduf = False
+
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNetSpatioTemporalConditionModel(
@@ -250,7 +255,8 @@ def test_float16_inference(self, expected_max_diff=5e-2):
max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
@@ -366,7 +372,7 @@ def test_save_load_local(self, expected_max_difference=9e-4):
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, expected_max_difference)
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -381,14 +387,14 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu")).frames[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cuda")).frames[0]
- self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+ output_device = pipe(**self.get_dummy_inputs(torch_device)).frames[0]
+ self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
@@ -402,10 +408,8 @@ def test_to_dtype(self):
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
- @unittest.skipIf(
- torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
- reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
- )
+ @require_accelerator
+ @require_accelerate_version_greater("0.14.0")
def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -419,7 +423,7 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
inputs = self.get_dummy_inputs(generator_device)
output_without_offload = pipe(**inputs).frames[0]
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
inputs = self.get_dummy_inputs(generator_device)
output_with_offload = pipe(**inputs).frames[0]
@@ -427,10 +431,8 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
- @unittest.skipIf(
- torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
- reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
- )
+ @require_accelerator
+ @require_accelerate_version_greater("0.17.0")
def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
generator_device = "cpu"
components = self.get_dummy_components()
@@ -446,7 +448,7 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
inputs = self.get_dummy_inputs(generator_device)
output_without_offload = pipe(**inputs).frames[0]
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_dummy_inputs(generator_device)
output_with_offload = pipe(**inputs).frames[0]
@@ -514,19 +516,19 @@ def test_disable_cfg(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableVideoDiffusionPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_sd_video(self):
pipe = StableVideoDiffusionPipeline.from_pretrained(
@@ -534,7 +536,7 @@ def test_sd_video(self):
variant="fp16",
torch_dtype=torch.float16,
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py
index 5eedd393c8f8..423c2b8ab146 100644
--- a/tests/pipelines/test_pipeline_utils.py
+++ b/tests/pipelines/test_pipeline_utils.py
@@ -18,8 +18,8 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
-from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
-from diffusers.utils.testing_utils import torch_device
+from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
+from diffusers.utils.testing_utils import require_torch_gpu, torch_device
class IsSafetensorsCompatibleTests(unittest.TestCase):
@@ -197,6 +197,388 @@ def test_diffusers_is_compatible_only_variants(self):
]
self.assertTrue(is_safetensors_compatible(filenames))
+ def test_diffusers_is_compatible_no_components(self):
+ filenames = [
+ "diffusion_pytorch_model.bin",
+ ]
+ self.assertFalse(is_safetensors_compatible(filenames))
+
+ def test_diffusers_is_compatible_no_components_only_variants(self):
+ filenames = [
+ "diffusion_pytorch_model.fp16.bin",
+ ]
+ self.assertFalse(is_safetensors_compatible(filenames))
+
+
+class VariantCompatibleSiblingsTest(unittest.TestCase):
+ def test_only_non_variants_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = [
+ f"vae/diffusion_pytorch_model.{variant}.safetensors",
+ "vae/diffusion_pytorch_model.safetensors",
+ f"text_encoder/model.{variant}.safetensors",
+ "text_encoder/model.safetensors",
+ f"unet/diffusion_pytorch_model.{variant}.safetensors",
+ "unet/diffusion_pytorch_model.safetensors",
+ ]
+
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=None, ignore_patterns=ignore_patterns
+ )
+ assert all(variant not in f for f in model_filenames)
+
+ def test_only_variants_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = [
+ f"vae/diffusion_pytorch_model.{variant}.safetensors",
+ "vae/diffusion_pytorch_model.safetensors",
+ f"text_encoder/model.{variant}.safetensors",
+ "text_encoder/model.safetensors",
+ f"unet/diffusion_pytorch_model.{variant}.safetensors",
+ "unet/diffusion_pytorch_model.safetensors",
+ ]
+
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f for f in model_filenames)
+
+ def test_mixed_variants_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ non_variant_file = "text_encoder/model.safetensors"
+ filenames = [
+ f"vae/diffusion_pytorch_model.{variant}.safetensors",
+ "vae/diffusion_pytorch_model.safetensors",
+ "text_encoder/model.safetensors",
+ f"unet/diffusion_pytorch_model.{variant}.safetensors",
+ "unet/diffusion_pytorch_model.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
+
+ def test_non_variants_in_main_dir_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = [
+ f"diffusion_pytorch_model.{variant}.safetensors",
+ "diffusion_pytorch_model.safetensors",
+ "model.safetensors",
+ f"model.{variant}.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=None, ignore_patterns=ignore_patterns
+ )
+ assert all(variant not in f for f in model_filenames)
+
+ def test_variants_in_main_dir_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = [
+ f"diffusion_pytorch_model.{variant}.safetensors",
+ "diffusion_pytorch_model.safetensors",
+ "model.safetensors",
+ f"model.{variant}.safetensors",
+ f"diffusion_pytorch_model.{variant}.safetensors",
+ "diffusion_pytorch_model.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f for f in model_filenames)
+
+ def test_mixed_variants_in_main_dir_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ non_variant_file = "model.safetensors"
+ filenames = [
+ f"diffusion_pytorch_model.{variant}.safetensors",
+ "diffusion_pytorch_model.safetensors",
+ "model.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames)
+
+ def test_sharded_variants_in_main_dir_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = [
+ "diffusion_pytorch_model.safetensors.index.json",
+ "diffusion_pytorch_model-00001-of-00003.safetensors",
+ "diffusion_pytorch_model-00002-of-00003.safetensors",
+ "diffusion_pytorch_model-00003-of-00003.safetensors",
+ f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
+ f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
+ f"diffusion_pytorch_model.safetensors.index.{variant}.json",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f for f in model_filenames)
+
+ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = [
+ "diffusion_pytorch_model.safetensors.index.json",
+ "diffusion_pytorch_model-00001-of-00003.safetensors",
+ "diffusion_pytorch_model-00002-of-00003.safetensors",
+ "diffusion_pytorch_model-00003-of-00003.safetensors",
+ f"diffusion_pytorch_model.{variant}.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f for f in model_filenames)
+
+ def test_mixed_sharded_non_variants_in_main_dir_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = [
+ f"diffusion_pytorch_model.safetensors.index.{variant}.json",
+ "diffusion_pytorch_model.safetensors.index.json",
+ "diffusion_pytorch_model-00001-of-00003.safetensors",
+ "diffusion_pytorch_model-00002-of-00003.safetensors",
+ "diffusion_pytorch_model-00003-of-00003.safetensors",
+ f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
+ f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=None, ignore_patterns=ignore_patterns
+ )
+ assert all(variant not in f for f in model_filenames)
+
+ def test_sharded_non_variants_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = [
+ f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
+ "unet/diffusion_pytorch_model.safetensors.index.json",
+ "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
+ f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
+ f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=None, ignore_patterns=ignore_patterns
+ )
+ assert all(variant not in f for f in model_filenames)
+
+ def test_sharded_variants_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = [
+ f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
+ "unet/diffusion_pytorch_model.safetensors.index.json",
+ "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
+ f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
+ f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f for f in model_filenames)
+ assert model_filenames == variant_filenames
+
+ def test_single_variant_with_sharded_non_variant_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = [
+ "unet/diffusion_pytorch_model.safetensors.index.json",
+ "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
+ f"unet/diffusion_pytorch_model.{variant}.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f for f in model_filenames)
+
+ def test_mixed_single_variant_with_sharded_non_variant_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ allowed_non_variant = "unet"
+ filenames = [
+ "vae/diffusion_pytorch_model.safetensors.index.json",
+ "vae/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "vae/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "vae/diffusion_pytorch_model-00003-of-00003.safetensors",
+ f"vae/diffusion_pytorch_model.{variant}.safetensors",
+ "unet/diffusion_pytorch_model.safetensors.index.json",
+ "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
+
+ def test_sharded_mixed_variants_downloaded(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ allowed_non_variant = "unet"
+ filenames = [
+ f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json",
+ "vae/diffusion_pytorch_model.safetensors.index.json",
+ "unet/diffusion_pytorch_model.safetensors.index.json",
+ "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
+ f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
+ f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
+ "vae/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "vae/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "vae/diffusion_pytorch_model-00003-of-00003.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
+
+ def test_downloading_when_no_variant_exists(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"]
+ with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+
+ def test_downloading_use_safetensors_false(self):
+ ignore_patterns = ["*.safetensors"]
+ filenames = [
+ "text_encoder/model.bin",
+ "unet/diffusion_pytorch_model.bin",
+ "unet/diffusion_pytorch_model.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=None, ignore_patterns=ignore_patterns
+ )
+
+ assert all(".safetensors" not in f for f in model_filenames)
+
+ def test_non_variant_in_main_dir_with_variant_in_subfolder(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ allowed_non_variant = "diffusion_pytorch_model.safetensors"
+ filenames = [
+ f"unet/diffusion_pytorch_model.{variant}.safetensors",
+ "diffusion_pytorch_model.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
+
+ def test_download_variants_when_component_has_no_safetensors_variant(self):
+ ignore_patterns = None
+ variant = "fp16"
+ filenames = [
+ f"unet/diffusion_pytorch_model.{variant}.bin",
+ "vae/diffusion_pytorch_model.safetensors",
+ f"vae/diffusion_pytorch_model.{variant}.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert {
+ f"unet/diffusion_pytorch_model.{variant}.bin",
+ f"vae/diffusion_pytorch_model.{variant}.safetensors",
+ } == model_filenames
+
+ def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant(self):
+ ignore_patterns = ["*.bin"]
+ variant = "fp16"
+ filenames = [
+ f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
+ "vae/diffusion_pytorch_model.safetensors.index.json",
+ f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
+ "vae/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "vae/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "vae/diffusion_pytorch_model-00003-of-00003.safetensors",
+ "unet/diffusion_pytorch_model.safetensors.index.json",
+ "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
+ f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
+ ]
+ with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "):
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+
+ def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self):
+ ignore_patterns = ["*.safetensors"]
+ allowed_non_variant = "unet"
+ variant = "fp16"
+ filenames = [
+ f"vae/diffusion_pytorch_model.bin.index.{variant}.json",
+ "vae/diffusion_pytorch_model.safetensors.index.json",
+ f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin",
+ "vae/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "vae/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "vae/diffusion_pytorch_model-00003-of-00003.safetensors",
+ "unet/diffusion_pytorch_model.safetensors.index.json",
+ "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
+ f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames)
+
+ def test_download_sharded_legacy_variants(self):
+ ignore_patterns = None
+ variant = "fp16"
+ filenames = [
+ f"vae/transformer/diffusion_pytorch_model.safetensors.{variant}.index.json",
+ "vae/diffusion_pytorch_model.safetensors.index.json",
+ f"vae/diffusion_pytorch_model-00002-of-00002.{variant}.safetensors",
+ "vae/diffusion_pytorch_model-00001-of-00003.safetensors",
+ "vae/diffusion_pytorch_model-00002-of-00003.safetensors",
+ "vae/diffusion_pytorch_model-00003-of-00003.safetensors",
+ f"vae/diffusion_pytorch_model-00001-of-00002.{variant}.safetensors",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=variant, ignore_patterns=ignore_patterns
+ )
+ assert all(variant in f for f in model_filenames)
+
+ def test_download_onnx_models(self):
+ ignore_patterns = ["*.safetensors"]
+ filenames = [
+ "vae/model.onnx",
+ "unet/model.onnx",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=None, ignore_patterns=ignore_patterns
+ )
+ assert model_filenames == set(filenames)
+
+ def test_download_flax_models(self):
+ ignore_patterns = ["*.safetensors", "*.bin"]
+ filenames = [
+ "vae/diffusion_flax_model.msgpack",
+ "unet/diffusion_flax_model.msgpack",
+ ]
+ model_filenames, variant_filenames = variant_compatible_siblings(
+ filenames, variant=None, ignore_patterns=ignore_patterns
+ )
+ assert model_filenames == set(filenames)
+
class ProgressBarTests(unittest.TestCase):
def get_dummy_components_image_generation(self):
@@ -444,3 +826,104 @@ def test_video_to_video(self):
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
+
+
+@require_torch_gpu
+class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
+ expected_pipe_device = torch.device("cuda:0")
+ expected_pipe_dtype = torch.float64
+
+ def get_dummy_components_image_generation(self):
+ cross_attention_dim = 8
+
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(4, 8),
+ layers_per_block=1,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=2,
+ )
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[4, 8],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ norm_num_groups=2,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=cross_attention_dim,
+ intermediate_size=16,
+ layer_norm_eps=1e-05,
+ num_attention_heads=2,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "safety_checker": None,
+ "feature_extractor": None,
+ "image_encoder": None,
+ }
+ return components
+
+ def test_deterministic_device(self):
+ components = self.get_dummy_components_image_generation()
+
+ pipe = StableDiffusionPipeline(**components)
+ pipe.to(device=torch_device, dtype=torch.float32)
+
+ pipe.unet.to(device="cpu")
+ pipe.vae.to(device="cuda")
+ pipe.text_encoder.to(device="cuda:0")
+
+ pipe_device = pipe.device
+
+ self.assertEqual(
+ self.expected_pipe_device,
+ pipe_device,
+ f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.",
+ )
+
+ def test_deterministic_dtype(self):
+ components = self.get_dummy_components_image_generation()
+
+ pipe = StableDiffusionPipeline(**components)
+ pipe.to(device=torch_device, dtype=torch.float32)
+
+ pipe.unet.to(dtype=torch.float16)
+ pipe.vae.to(dtype=torch.float32)
+ pipe.text_encoder.to(dtype=torch.float64)
+
+ pipe_dtype = pipe.dtype
+
+ self.assertEqual(
+ self.expected_pipe_dtype,
+ pipe_dtype,
+ f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.",
+ )
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 43b01c40f5bb..48c89d399216 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -66,6 +66,7 @@
)
from diffusers.utils.testing_utils import (
CaptureLogger,
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
get_python_version,
@@ -75,9 +76,11 @@
nightly,
require_compel,
require_flax,
+ require_hf_hub_version_greater,
require_onnxruntime,
require_torch_2,
- require_torch_gpu,
+ require_torch_accelerator,
+ require_transformers_version_greater,
run_test_in_subprocess,
slow,
torch_device,
@@ -981,6 +984,18 @@ def test_download_ignore_files(self):
assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files)
assert len(files) == 14
+ def test_download_dduf_with_custom_pipeline_raises_error(self):
+ with self.assertRaises(NotImplementedError):
+ _ = DiffusionPipeline.download(
+ "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", custom_pipeline="my_pipeline"
+ )
+
+ def test_download_dduf_with_connected_pipeline_raises_error(self):
+ with self.assertRaises(NotImplementedError):
+ _ = DiffusionPipeline.download(
+ "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True
+ )
+
def test_get_pipeline_class_from_flax(self):
flax_config = {"_class_name": "FlaxStableDiffusionPipeline"}
config = {"_class_name": "StableDiffusionPipeline"}
@@ -1136,7 +1151,7 @@ def test_custom_model_and_pipeline(self):
assert conf_1 == conf_2
@slow
- @require_torch_gpu
+ @require_torch_accelerator
def test_download_from_git(self):
# Because adaptive_avg_pool2d_backward_cuda
# does not have a deterministic implementation.
@@ -1350,7 +1365,7 @@ def test_stable_diffusion_components(self):
assert image_img2img.shape == (1, 32, 32, 3)
assert image_text2img.shape == (1, 64, 64, 3)
- @require_torch_gpu
+ @require_torch_accelerator
def test_pipe_false_offload_warn(self):
unet = self.dummy_cond_unet()
scheduler = PNDMScheduler(skip_prk_steps=True)
@@ -1368,11 +1383,11 @@ def test_pipe_false_offload_warn(self):
feature_extractor=self.dummy_extractor,
)
- sd.enable_model_cpu_offload()
+ sd.enable_model_cpu_offload(device=torch_device)
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
with CaptureLogger(logger) as cap_logger:
- sd.to("cuda")
+ sd.to(torch_device)
assert "It is strongly recommended against doing so" in str(cap_logger)
@@ -1802,21 +1817,101 @@ def test_pipe_same_device_id_offload(self):
sd.maybe_free_model_hooks()
assert sd._offload_gpu_id == 5
+ @parameterized.expand([torch.float32, torch.float16])
+ @require_hf_hub_version_greater("0.26.5")
+ @require_transformers_version_greater("4.47.1")
+ def test_load_dduf_from_hub(self, dtype):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe = DiffusionPipeline.from_pretrained(
+ "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir, torch_dtype=dtype
+ ).to(torch_device)
+ out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images
+
+ pipe.save_pretrained(tmpdir)
+ loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=dtype).to(torch_device)
+
+ out_2 = loaded_pipe(
+ prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np"
+ ).images
+
+ self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4))
+
+ @require_hf_hub_version_greater("0.26.5")
+ @require_transformers_version_greater("4.47.1")
+ def test_load_dduf_from_hub_local_files_only(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe = DiffusionPipeline.from_pretrained(
+ "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir
+ ).to(torch_device)
+ out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images
+
+ local_files_pipe = DiffusionPipeline.from_pretrained(
+ "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir, local_files_only=True
+ ).to(torch_device)
+ out_2 = local_files_pipe(
+ prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np"
+ ).images
+
+ self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4))
+
+ def test_dduf_raises_error_with_custom_pipeline(self):
+ with self.assertRaises(NotImplementedError):
+ _ = DiffusionPipeline.from_pretrained(
+ "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", custom_pipeline="my_pipeline"
+ )
+
+ def test_dduf_raises_error_with_connected_pipeline(self):
+ with self.assertRaises(NotImplementedError):
+ _ = DiffusionPipeline.from_pretrained(
+ "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True
+ )
+
+ def test_wrong_model(self):
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ with self.assertRaises(ValueError) as error_context:
+ _ = StableDiffusionPipeline.from_pretrained(
+ "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
+ )
+
+ assert "is of type" in str(error_context.exception)
+ assert "but should be" in str(error_context.exception)
+
+ @require_hf_hub_version_greater("0.26.5")
+ @require_transformers_version_greater("4.47.1")
+ def test_dduf_load_sharded_checkpoint_diffusion_model(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-flux-dev-pipe-sharded-checkpoint-DDUF",
+ dduf_file="tiny-flux-dev-pipe-sharded-checkpoint.dduf",
+ cache_dir=tmpdir,
+ ).to(torch_device)
+
+ out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images
+
+ pipe.save_pretrained(tmpdir)
+ loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir).to(torch_device)
+
+ out_2 = loaded_pipe(
+ prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np"
+ ).images
+
+ self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4))
+
@slow
-@require_torch_gpu
+@require_torch_accelerator
class PipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_smart_download(self):
model_id = "hf-internal-testing/unet-pipeline-dummy"
@@ -2008,7 +2103,7 @@ def test_weighted_prompts_compel(self):
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.enable_attention_slicing()
compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
@@ -2035,19 +2130,19 @@ def test_weighted_prompts_compel(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class PipelineNightlyTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_ddpm_ddim_equality_batched(self):
seed = 0
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 3e6f9d1278e8..d3e39e363f91 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -23,35 +23,49 @@
ConsistencyDecoderVAE,
DDIMScheduler,
DiffusionPipeline,
+ FasterCacheConfig,
KolorsPipeline,
+ PyramidAttentionBroadcastConfig,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
+ apply_faster_cache,
)
+from diffusers.hooks import apply_group_offloading
+from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
+from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
-from diffusers.loaders import IPAdapterMixin
+from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
from diffusers.models.attention_processor import AttnProcessor
-from diffusers.models.controlnet_xs import UNetControlNetXSModel
+from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
from diffusers.models.unets.unet_motion_model import UNetMotionModel
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging
-from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
from diffusers.utils.testing_utils import (
CaptureLogger,
+ backend_empty_cache,
+ require_accelerate_version_greater,
+ require_accelerator,
+ require_hf_hub_version_greater,
require_torch,
+ require_torch_gpu,
+ require_transformers_version_greater,
skip_mps,
torch_device,
)
-from ..models.autoencoders.test_models_vae import (
+from ..models.autoencoders.vae import (
get_asym_autoencoder_kl_config,
get_autoencoder_kl_config,
get_autoencoder_tiny_config,
get_consistency_vae_config,
)
+from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict
from ..models.unets.test_models_unet_2d_condition import (
create_ip_adapter_faceid_state_dict,
create_ip_adapter_state_dict,
@@ -481,6 +495,135 @@ def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
)
+class FluxIPAdapterTesterMixin:
+ """
+ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
+ It provides a set of common tests for pipelines that support IP Adapters.
+ """
+
+ def test_pipeline_signature(self):
+ parameters = inspect.signature(self.pipeline_class.__call__).parameters
+
+ assert issubclass(self.pipeline_class, FluxIPAdapterMixin)
+ self.assertIn(
+ "ip_adapter_image",
+ parameters,
+ "`ip_adapter_image` argument must be supported by the `__call__` method",
+ )
+ self.assertIn(
+ "ip_adapter_image_embeds",
+ parameters,
+ "`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
+ )
+
+ def _get_dummy_image_embeds(self, image_embed_dim: int = 768):
+ return torch.randn((1, 1, image_embed_dim), device=torch_device)
+
+ def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
+ inputs["negative_prompt"] = ""
+ inputs["true_cfg_scale"] = 4.0
+ inputs["output_type"] = "np"
+ inputs["return_dict"] = False
+ return inputs
+
+ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
+ r"""Tests for IP-Adapter.
+
+ The following scenarios are tested:
+ - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+ - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+ - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+ - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+ """
+ # Raising the tolerance for this test when it's run on a CPU because we
+ # compare against static slices and that can be shaky (with a VVVV low probability).
+ expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components).to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ image_embed_dim = pipe.transformer.config.pooled_projection_dim
+
+ # forward pass without ip adapter
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ if expected_pipe_slice is None:
+ output_without_adapter = pipe(**inputs)[0]
+ else:
+ output_without_adapter = expected_pipe_slice
+
+ # 1. Single IP-Adapter test cases
+ adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer)
+ pipe.transformer._load_ip_adapter_weights(adapter_state_dict)
+
+ # forward pass with single ip adapter, but scale=0 which should have no effect
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
+ inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
+ pipe.set_ip_adapter_scale(0.0)
+ output_without_adapter_scale = pipe(**inputs)[0]
+ if expected_pipe_slice is not None:
+ output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ # forward pass with single ip adapter, but with scale of adapter weights
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
+ inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
+ pipe.set_ip_adapter_scale(42.0)
+ output_with_adapter_scale = pipe(**inputs)[0]
+ if expected_pipe_slice is not None:
+ output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
+ max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
+
+ self.assertLess(
+ max_diff_without_adapter_scale,
+ expected_max_diff,
+ "Output without ip-adapter must be same as normal inference",
+ )
+ self.assertGreater(
+ max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
+ )
+
+ # 2. Multi IP-Adapter test cases
+ adapter_state_dict_1 = create_flux_ip_adapter_state_dict(pipe.transformer)
+ adapter_state_dict_2 = create_flux_ip_adapter_state_dict(pipe.transformer)
+ pipe.transformer._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
+
+ # forward pass with multi ip adapter, but scale=0 which should have no effect
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
+ inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
+ pipe.set_ip_adapter_scale([0.0, 0.0])
+ output_without_multi_adapter_scale = pipe(**inputs)[0]
+ if expected_pipe_slice is not None:
+ output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ # forward pass with multi ip adapter, but with scale of adapter weights
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
+ inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
+ pipe.set_ip_adapter_scale([42.0, 42.0])
+ output_with_multi_adapter_scale = pipe(**inputs)[0]
+ if expected_pipe_slice is not None:
+ output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ max_diff_without_multi_adapter_scale = np.abs(
+ output_without_multi_adapter_scale - output_without_adapter
+ ).max()
+ max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
+ self.assertLess(
+ max_diff_without_multi_adapter_scale,
+ expected_max_diff,
+ "Output without multi-ip-adapter must be same as normal inference",
+ )
+ self.assertGreater(
+ max_diff_with_multi_adapter_scale,
+ 1e-2,
+ "Output with multi-ip-adapter scale must be different from normal inference",
+ )
+
+
class PipelineLatentTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
@@ -770,17 +913,15 @@ def test_from_pipe_consistent_forward_pass(self, expected_max_diff=1e-3):
type(proc) == AttnProcessor for proc in component.attn_processors.values()
), "`from_pipe` changed the attention processor in original pipeline."
- @unittest.skipIf(
- torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
- reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
- )
+ @require_accelerator
+ @require_accelerate_version_greater("0.14.0")
def test_from_pipe_consistent_forward_pass_cpu_offload(self, expected_max_diff=1e-3):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs_pipe(torch_device)
output = pipe(**inputs)[0]
@@ -815,7 +956,7 @@ def test_from_pipe_consistent_forward_pass_cpu_offload(self, expected_max_diff=1
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
- pipe_from_original.enable_model_cpu_offload()
+ pipe_from_original.enable_model_cpu_offload(device=torch_device)
pipe_from_original.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs_pipe(torch_device)
output_from_original = pipe_from_original(**inputs)[0]
@@ -896,6 +1037,9 @@ class PipelineTesterMixin:
test_attention_slicing = True
test_xformers_attention = True
+ test_layerwise_casting = False
+ test_group_offloading = False
+ supports_dduf = True
def get_generator(self, seed):
device = torch_device if torch_device != "mps" else "cpu"
@@ -968,13 +1112,13 @@ def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test in case of CUDA runtime errors
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_save_load_local(self, expected_max_difference=5e-4):
components = self.get_dummy_components()
@@ -1201,7 +1345,8 @@ def test_components_function(self):
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_float16_inference(self, expected_max_diff=5e-2):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -1238,7 +1383,8 @@ def test_float16_inference(self, expected_max_diff=5e-2):
max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
@@ -1281,7 +1427,6 @@ def test_save_load_float16(self, expected_max_diff=1e-2):
def test_save_load_optional_components(self, expected_max_difference=1e-4):
if not hasattr(self.pipeline_class, "_optional_components"):
return
-
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
@@ -1296,6 +1441,7 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
@@ -1314,12 +1460,13 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
)
inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, expected_max_difference)
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -1332,11 +1479,11 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+ output_cuda = pipe(**self.get_dummy_inputs(torch_device))[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
def test_to_dtype(self):
@@ -1393,10 +1540,8 @@ def _test_attention_slicing_forward_pass(
assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0]))
assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0]))
- @unittest.skipIf(
- torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
- reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
- )
+ @require_accelerator
+ @require_accelerate_version_greater("0.14.0")
def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
import accelerate
@@ -1410,12 +1555,14 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
output_without_offload = pipe(**inputs)[0]
- pipe.enable_sequential_cpu_offload()
- assert pipe._execution_device.type == "cuda"
+ pipe.enable_sequential_cpu_offload(device=torch_device)
+ assert pipe._execution_device.type == torch_device
inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
output_with_offload = pipe(**inputs)[0]
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
@@ -1456,10 +1603,8 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
)
- @unittest.skipIf(
- torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
- reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
- )
+ @require_accelerator
+ @require_accelerate_version_greater("0.17.0")
def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
import accelerate
@@ -1475,12 +1620,14 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
output_without_offload = pipe(**inputs)[0]
- pipe.enable_model_cpu_offload()
- assert pipe._execution_device.type == "cuda"
+ pipe.enable_model_cpu_offload(device=torch_device)
+ assert pipe._execution_device.type == torch_device
inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
output_with_offload = pipe(**inputs)[0]
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
@@ -1513,10 +1660,8 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
)
- @unittest.skipIf(
- torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
- reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
- )
+ @require_accelerator
+ @require_accelerate_version_greater("0.17.0")
def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
import accelerate
@@ -1530,11 +1675,11 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
pipe.set_progress_bar_config(disable=None)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_dummy_inputs(generator_device)
output_with_offload = pipe(**inputs)[0]
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_dummy_inputs(generator_device)
output_with_offload_twice = pipe(**inputs)[0]
@@ -1570,10 +1715,8 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
)
- @unittest.skipIf(
- torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
- reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher",
- )
+ @require_accelerator
+ @require_accelerate_version_greater("0.14.0")
def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
import accelerate
@@ -1587,11 +1730,11 @@ def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
pipe.set_progress_bar_config(disable=None)
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
inputs = self.get_dummy_inputs(generator_device)
output_with_offload = pipe(**inputs)[0]
- pipe.enable_sequential_cpu_offload()
+ pipe.enable_sequential_cpu_offload(device=torch_device)
inputs = self.get_dummy_inputs(generator_device)
output_with_offload_twice = pipe(**inputs)[0]
@@ -1892,6 +2035,118 @@ def test_loading_with_incorrect_variants_raises_error(self):
assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception)
+ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
+ if not hasattr(self.pipeline_class, "encode_prompt"):
+ return
+
+ components = self.get_dummy_components()
+
+ # We initialize the pipeline with only text encoders and tokenizers,
+ # mimicking a real-world scenario.
+ components_with_text_encoders = {}
+ for k in components:
+ if "text" in k or "tokenizer" in k:
+ components_with_text_encoders[k] = components[k]
+ else:
+ components_with_text_encoders[k] = None
+ pipe_with_just_text_encoder = self.pipeline_class(**components_with_text_encoders)
+ pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)
+
+ # Get inputs and also the args of `encode_prompts`.
+ inputs = self.get_dummy_inputs(torch_device)
+ encode_prompt_signature = inspect.signature(pipe_with_just_text_encoder.encode_prompt)
+ encode_prompt_parameters = list(encode_prompt_signature.parameters.values())
+
+ # Required args in encode_prompt with those with no default.
+ required_params = []
+ for param in encode_prompt_parameters:
+ if param.name == "self" or param.name == "kwargs":
+ continue
+ if param.default is inspect.Parameter.empty:
+ required_params.append(param.name)
+
+ # Craft inputs for the `encode_prompt()` method to run in isolation.
+ encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"]
+ input_keys = list(inputs.keys())
+ encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names}
+
+ pipe_call_signature = inspect.signature(pipe_with_just_text_encoder.__call__)
+ pipe_call_parameters = pipe_call_signature.parameters
+
+ # For each required arg in encode_prompt, check if it's missing
+ # in encode_prompt_inputs. If so, see if __call__ has a default
+ # for that arg and use it if available.
+ for required_param_name in required_params:
+ if required_param_name not in encode_prompt_inputs:
+ pipe_call_param = pipe_call_parameters.get(required_param_name, None)
+ if pipe_call_param is not None and pipe_call_param.default is not inspect.Parameter.empty:
+ # Use the default from pipe.__call__
+ encode_prompt_inputs[required_param_name] = pipe_call_param.default
+ elif extra_required_param_value_dict is not None and isinstance(extra_required_param_value_dict, dict):
+ encode_prompt_inputs[required_param_name] = extra_required_param_value_dict[required_param_name]
+ else:
+ raise ValueError(
+ f"Required parameter '{required_param_name}' in "
+ f"encode_prompt has no default in either encode_prompt or __call__."
+ )
+
+ # Compute `encode_prompt()`.
+ with torch.no_grad():
+ encoded_prompt_outputs = pipe_with_just_text_encoder.encode_prompt(**encode_prompt_inputs)
+
+ # Programatically determine the reutrn names of `encode_prompt.`
+ ast_vistor = ReturnNameVisitor()
+ encode_prompt_tree = ast_vistor.get_ast_tree(cls=self.pipeline_class)
+ ast_vistor.visit(encode_prompt_tree)
+ prompt_embed_kwargs = ast_vistor.return_names
+ prompt_embeds_kwargs = dict(zip(prompt_embed_kwargs, encoded_prompt_outputs))
+
+ # Pack the outputs of `encode_prompt`.
+ adapted_prompt_embeds_kwargs = {
+ k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
+ }
+
+ # now initialize a pipeline without text encoders and compute outputs with the
+ # `encode_prompt()` outputs and other relevant inputs.
+ components_with_text_encoders = {}
+ for k in components:
+ if "text" in k or "tokenizer" in k:
+ components_with_text_encoders[k] = None
+ else:
+ components_with_text_encoders[k] = components[k]
+ pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device)
+
+ # Set `negative_prompt` to None as we have already calculated its embeds
+ # if it was present in `inputs`. This is because otherwise we will interfere wrongly
+ # for non-None `negative_prompt` values as defaults (PixArt for example).
+ pipe_without_tes_inputs = {**inputs, **adapted_prompt_embeds_kwargs}
+ if (
+ pipe_call_parameters.get("negative_prompt", None) is not None
+ and pipe_call_parameters.get("negative_prompt").default is not None
+ ):
+ pipe_without_tes_inputs.update({"negative_prompt": None})
+
+ # Pipelines like attend and excite have `prompt` as a required argument.
+ if (
+ pipe_call_parameters.get("prompt", None) is not None
+ and pipe_call_parameters.get("prompt").default is inspect.Parameter.empty
+ and pipe_call_parameters.get("prompt_embeds", None) is not None
+ and pipe_call_parameters.get("prompt_embeds").default is None
+ ):
+ pipe_without_tes_inputs.update({"prompt": None})
+
+ pipe_out = pipe_without_text_encoders(**pipe_without_tes_inputs)[0]
+
+ # Compare against regular pipeline outputs.
+ full_pipe = self.pipeline_class(**components).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ pipe_out_2 = full_pipe(**inputs)[0]
+
+ if isinstance(pipe_out, np.ndarray) and isinstance(pipe_out_2, np.ndarray):
+ self.assertTrue(np.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol))
+ elif isinstance(pipe_out, torch.Tensor) and isinstance(pipe_out_2, torch.Tensor):
+ self.assertTrue(torch.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol))
+
def test_StableDiffusionMixin_component(self):
"""Any pipeline that have LDMFuncMixin should have vae and unet components."""
if not issubclass(self.pipeline_class, StableDiffusionMixin):
@@ -1907,6 +2162,150 @@ def test_StableDiffusionMixin_component(self):
)
)
+ @require_hf_hub_version_greater("0.26.5")
+ @require_transformers_version_greater("4.47.1")
+ def test_save_load_dduf(self, atol=1e-4, rtol=1e-4):
+ if not self.supports_dduf:
+ return
+
+ from huggingface_hub import export_folder_as_dduf
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device="cpu")
+ inputs.pop("generator")
+ inputs["generator"] = torch.manual_seed(0)
+
+ pipeline_out = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf")
+ pipe.save_pretrained(tmpdir, safe_serialization=True)
+ export_folder_as_dduf(dduf_filename, folder_path=tmpdir)
+ loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device)
+
+ inputs["generator"] = torch.manual_seed(0)
+ loaded_pipeline_out = loaded_pipe(**inputs)[0]
+
+ if isinstance(pipeline_out, np.ndarray) and isinstance(loaded_pipeline_out, np.ndarray):
+ assert np.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol)
+ elif isinstance(pipeline_out, torch.Tensor) and isinstance(loaded_pipeline_out, torch.Tensor):
+ assert torch.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol)
+
+ def test_layerwise_casting_inference(self):
+ if not self.test_layerwise_casting:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device, dtype=torch.bfloat16)
+ pipe.set_progress_bar_config(disable=None)
+
+ denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+ denoiser.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ _ = pipe(**inputs)[0]
+
+ @require_torch_gpu
+ def test_group_offloading_inference(self):
+ if not self.test_group_offloading:
+ return
+
+ def create_pipe():
+ torch.manual_seed(0)
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ return pipe
+
+ def enable_group_offload_on_component(pipe, group_offloading_kwargs):
+ # We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If
+ # tiling is enabled and a forward pass is run, when cuda streams are used, the execution order of
+ # the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a
+ # warmup forward pass (even with dummy small inputs) is recommended.
+ for component_name in [
+ "text_encoder",
+ "text_encoder_2",
+ "text_encoder_3",
+ "transformer",
+ "unet",
+ "controlnet",
+ ]:
+ if not hasattr(pipe, component_name):
+ continue
+ component = getattr(pipe, component_name)
+ if not getattr(component, "_supports_group_offloading", True):
+ continue
+ if hasattr(component, "enable_group_offload"):
+ # For diffusers ModelMixin implementations
+ component.enable_group_offload(torch.device(torch_device), **group_offloading_kwargs)
+ else:
+ # For other models not part of diffusers
+ apply_group_offloading(
+ component, onload_device=torch.device(torch_device), **group_offloading_kwargs
+ )
+ self.assertTrue(
+ all(
+ module._diffusers_hook.get_hook("group_offloading") is not None
+ for module in component.modules()
+ if hasattr(module, "_diffusers_hook")
+ )
+ )
+ for component_name in ["vae", "vqvae"]:
+ if hasattr(pipe, component_name):
+ getattr(pipe, component_name).to(torch_device)
+
+ def run_forward(pipe):
+ torch.manual_seed(0)
+ inputs = self.get_dummy_inputs(torch_device)
+ return pipe(**inputs)[0]
+
+ pipe = create_pipe().to(torch_device)
+ output_without_group_offloading = run_forward(pipe)
+
+ pipe = create_pipe()
+ enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1})
+ output_with_group_offloading1 = run_forward(pipe)
+
+ pipe = create_pipe()
+ enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"})
+ output_with_group_offloading2 = run_forward(pipe)
+
+ if torch.is_tensor(output_without_group_offloading):
+ output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
+ output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy()
+ output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy()
+
+ self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
+ self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
+
+ def test_torch_dtype_dict(self):
+ components = self.get_dummy_components()
+ if not components:
+ self.skipTest("No dummy components defined.")
+
+ pipe = self.pipeline_class(**components)
+
+ specified_key = next(iter(components.keys()))
+
+ with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
+ pipe.save_pretrained(tmpdirname, safe_serialization=False)
+ torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
+ loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
+
+ for name, component in loaded_pipe.components.items():
+ if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
+ expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
+ self.assertEqual(
+ component.dtype,
+ expected_dtype,
+ f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
+ )
+
@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):
@@ -2043,148 +2442,300 @@ def test_push_to_hub_library_name(self):
delete_repo(self.repo_id, token=TOKEN)
-# For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders
-# and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()`
-# test for all such pipelines. This requires us to use a custom `encode_prompt()` function.
-class SDXLOptionalComponentsTesterMixin:
- def encode_prompt(
- self, tokenizers, text_encoders, prompt: str, num_images_per_prompt: int = 1, negative_prompt: str = None
- ):
- device = text_encoders[0].device
-
- if isinstance(prompt, str):
- prompt = [prompt]
- batch_size = len(prompt)
-
- prompt_embeds_list = []
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
- text_inputs = tokenizer(
- prompt,
- padding="max_length",
- max_length=tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
+class PyramidAttentionBroadcastTesterMixin:
+ pab_config = PyramidAttentionBroadcastConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(100, 800),
+ spatial_attention_block_identifiers=["transformer_blocks"],
+ )
- text_input_ids = text_inputs.input_ids
+ def test_pyramid_attention_broadcast_layers(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
- prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
- pooled_prompt_embeds = prompt_embeds[0]
- prompt_embeds = prompt_embeds.hidden_states[-2]
- prompt_embeds_list.append(prompt_embeds)
+ num_layers = 0
+ num_single_layers = 0
+ dummy_component_kwargs = {}
+ dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters
+ if "num_layers" in dummy_component_parameters:
+ num_layers = 2
+ dummy_component_kwargs["num_layers"] = num_layers
+ if "num_single_layers" in dummy_component_parameters:
+ num_single_layers = 2
+ dummy_component_kwargs["num_single_layers"] = num_single_layers
+
+ components = self.get_dummy_components(**dummy_component_kwargs)
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
- prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ self.pab_config.current_timestep_callback = lambda: pipe.current_timestep
+ denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+ denoiser.enable_cache(self.pab_config)
+
+ expected_hooks = 0
+ if self.pab_config.spatial_attention_block_skip_range is not None:
+ expected_hooks += num_layers + num_single_layers
+ if self.pab_config.temporal_attention_block_skip_range is not None:
+ expected_hooks += num_layers + num_single_layers
+ if self.pab_config.cross_attention_block_skip_range is not None:
+ expected_hooks += num_layers + num_single_layers
+
+ denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+ count = 0
+ for module in denoiser.modules():
+ if hasattr(module, "_diffusers_hook"):
+ hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
+ if hook is None:
+ continue
+ count += 1
+ self.assertTrue(
+ isinstance(hook, PyramidAttentionBroadcastHook),
+ "Hook should be of type PyramidAttentionBroadcastHook.",
+ )
+ self.assertTrue(hook.state.cache is None, "Cache should be None at initialization.")
+ self.assertEqual(count, expected_hooks, "Number of hooks should match the expected number.")
+
+ # Perform dummy inference step to ensure state is updated
+ def pab_state_check_callback(pipe, i, t, kwargs):
+ for module in denoiser.modules():
+ if hasattr(module, "_diffusers_hook"):
+ hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
+ if hook is None:
+ continue
+ self.assertTrue(
+ hook.state.cache is not None,
+ "Cache should have updated during inference.",
+ )
+ self.assertTrue(
+ hook.state.iteration == i + 1,
+ "Hook iteration state should have updated during inference.",
+ )
+ return {}
- if negative_prompt is None:
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
- negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
- else:
- negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-
- negative_prompt_embeds_list = []
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
- uncond_input = tokenizer(
- negative_prompt,
- padding="max_length",
- max_length=tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 2
+ inputs["callback_on_step_end"] = pab_state_check_callback
+ pipe(**inputs)[0]
+
+ # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
+ for module in denoiser.modules():
+ if hasattr(module, "_diffusers_hook"):
+ hook = module._diffusers_hook.get_hook("pyramid_attention_broadcast")
+ if hook is None:
+ continue
+ self.assertTrue(
+ hook.state.cache is None,
+ "Cache should be reset to None after inference.",
+ )
+ self.assertTrue(
+ hook.state.iteration == 0,
+ "Iteration should be reset to 0 after inference.",
)
- negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True)
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
- negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
- negative_prompt_embeds_list.append(negative_prompt_embeds)
+ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2):
+ # We need to use higher tolerance because we are using a random model. With a converged/trained
+ # model, the tolerance can be lower.
- negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ num_layers = 2
+ components = self.get_dummy_components(num_layers=num_layers)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
- bs_embed, seq_len, _ = prompt_embeds.shape
+ # Run inference without PAB
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 4
+ output = pipe(**inputs)[0]
+ original_image_slice = output.flatten()
+ original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:]))
- # duplicate text embeddings for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ # Run inference with PAB enabled
+ self.pab_config.current_timestep_callback = lambda: pipe.current_timestep
+ denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+ denoiser.enable_cache(self.pab_config)
- # for classifier-free guidance
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
- seq_len = negative_prompt_embeds.shape[1]
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 4
+ output = pipe(**inputs)[0]
+ image_slice_pab_enabled = output.flatten()
+ image_slice_pab_enabled = np.concatenate((image_slice_pab_enabled[:8], image_slice_pab_enabled[-8:]))
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ # Run inference with PAB disabled
+ denoiser.disable_cache()
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
- bs_embed * num_images_per_prompt, -1
- )
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 4
+ output = pipe(**inputs)[0]
+ image_slice_pab_disabled = output.flatten()
+ image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:]))
+
+ assert np.allclose(
+ original_image_slice, image_slice_pab_enabled, atol=expected_atol
+ ), "PAB outputs should not differ much in specified timestep range."
+ assert np.allclose(
+ original_image_slice, image_slice_pab_disabled, atol=1e-4
+ ), "Outputs from normal inference and after disabling cache should not differ."
- # for classifier-free guidance
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
- bs_embed * num_images_per_prompt, -1
- )
- return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+class FasterCacheTesterMixin:
+ faster_cache_config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 901),
+ unconditional_batch_skip_range=2,
+ attention_weight_callback=lambda _: 0.5,
+ )
- def _test_save_load_optional_components(self, expected_max_difference=1e-4):
+ def test_faster_cache_basic_warning_or_errors_raised(self):
components = self.get_dummy_components()
+ logger = logging.get_logger("diffusers.hooks.faster_cache")
+ logger.setLevel(logging.INFO)
+
+ # Check if warning is raise when no attention_weight_callback is provided
pipe = self.pipeline_class(**components)
- for optional_component in pipe._optional_components:
- setattr(pipe, optional_component, None)
+ with CaptureLogger(logger) as cap_logger:
+ config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None)
+ apply_faster_cache(pipe.transformer, config)
+ self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out)
- for component in pipe.components.values():
- if hasattr(component, "set_default_attn_processor"):
- component.set_default_attn_processor()
- pipe.to(torch_device)
+ # Check if error raised when unsupported tensor format used
+ pipe = self.pipeline_class(**components)
+ with self.assertRaises(ValueError):
+ config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC")
+ apply_faster_cache(pipe.transformer, config)
+
+ def test_faster_cache_inference(self, expected_atol: float = 0.1):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+
+ def create_pipe():
+ torch.manual_seed(0)
+ num_layers = 2
+ components = self.get_dummy_components(num_layers=num_layers)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+ return pipe
+
+ def run_forward(pipe):
+ torch.manual_seed(0)
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 4
+ return pipe(**inputs)[0]
+
+ # Run inference without FasterCache
+ pipe = create_pipe()
+ output = run_forward(pipe).flatten()
+ original_image_slice = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with FasterCache enabled
+ self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
+ pipe = create_pipe()
+ pipe.transformer.enable_cache(self.faster_cache_config)
+ output = run_forward(pipe).flatten().flatten()
+ image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with FasterCache disabled
+ pipe.transformer.disable_cache()
+ output = run_forward(pipe).flatten()
+ image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:]))
+
+ assert np.allclose(
+ original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol
+ ), "FasterCache outputs should not differ much in specified timestep range."
+ assert np.allclose(
+ original_image_slice, image_slice_faster_cache_disabled, atol=1e-4
+ ), "Outputs from normal inference and after disabling cache should not differ."
+
+ def test_faster_cache_state(self):
+ from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
+
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ num_layers = 0
+ num_single_layers = 0
+ dummy_component_kwargs = {}
+ dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters
+ if "num_layers" in dummy_component_parameters:
+ num_layers = 2
+ dummy_component_kwargs["num_layers"] = num_layers
+ if "num_single_layers" in dummy_component_parameters:
+ num_single_layers = 2
+ dummy_component_kwargs["num_single_layers"] = num_single_layers
+
+ components = self.get_dummy_components(**dummy_component_kwargs)
+ pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
- generator_device = "cpu"
- inputs = self.get_dummy_inputs(generator_device)
+ self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
+ pipe.transformer.enable_cache(self.faster_cache_config)
- tokenizer = components.pop("tokenizer")
- tokenizer_2 = components.pop("tokenizer_2")
- text_encoder = components.pop("text_encoder")
- text_encoder_2 = components.pop("text_encoder_2")
-
- tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2]
- text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2]
- prompt = inputs.pop("prompt")
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = self.encode_prompt(tokenizers, text_encoders, prompt)
- inputs["prompt_embeds"] = prompt_embeds
- inputs["negative_prompt_embeds"] = negative_prompt_embeds
- inputs["pooled_prompt_embeds"] = pooled_prompt_embeds
- inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
+ expected_hooks = 0
+ if self.faster_cache_config.spatial_attention_block_skip_range is not None:
+ expected_hooks += num_layers + num_single_layers
+ if self.faster_cache_config.temporal_attention_block_skip_range is not None:
+ expected_hooks += num_layers + num_single_layers
- output = pipe(**inputs)[0]
+ # Check if faster_cache denoiser hook is attached
+ denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
+ self.assertTrue(
+ hasattr(denoiser, "_diffusers_hook")
+ and isinstance(denoiser._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK), FasterCacheDenoiserHook),
+ "Hook should be of type FasterCacheDenoiserHook.",
+ )
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
- for component in pipe_loaded.components.values():
- if hasattr(component, "set_default_attn_processor"):
- component.set_default_attn_processor()
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
+ # Check if all blocks have faster_cache block hook attached
+ count = 0
+ for name, module in denoiser.named_modules():
+ if hasattr(module, "_diffusers_hook"):
+ if name == "":
+ # Skip the root denoiser module
+ continue
+ count += 1
+ self.assertTrue(
+ isinstance(module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK), FasterCacheBlockHook),
+ "Hook should be of type FasterCacheBlockHook.",
+ )
+ self.assertEqual(count, expected_hooks, "Number of hooks should match expected number.")
- for optional_component in pipe._optional_components:
- self.assertTrue(
- getattr(pipe_loaded, optional_component) is None,
- f"`{optional_component}` did not stay set to None after loading.",
- )
+ # Perform inference to ensure that states are updated correctly
+ def faster_cache_state_check_callback(pipe, i, t, kwargs):
+ for name, module in denoiser.named_modules():
+ if not hasattr(module, "_diffusers_hook"):
+ continue
+ if name == "":
+ # Root denoiser module
+ state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state
+ if not self.faster_cache_config.is_guidance_distilled:
+ self.assertTrue(state.low_frequency_delta is not None, "Low frequency delta should be set.")
+ self.assertTrue(state.high_frequency_delta is not None, "High frequency delta should be set.")
+ else:
+ # Internal blocks
+ state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state
+ self.assertTrue(state.cache is not None and len(state.cache) == 2, "Cache should be set.")
+ self.assertTrue(state.iteration == i + 1, "Hook iteration state should have updated during inference.")
+ return {}
- inputs = self.get_dummy_inputs(generator_device)
- _ = inputs.pop("prompt")
- inputs["prompt_embeds"] = prompt_embeds
- inputs["negative_prompt_embeds"] = negative_prompt_embeds
- inputs["pooled_prompt_embeds"] = pooled_prompt_embeds
- inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 4
+ inputs["callback_on_step_end"] = faster_cache_state_check_callback
+ _ = pipe(**inputs)[0]
- output_loaded = pipe_loaded(**inputs)[0]
+ # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states
+ for name, module in denoiser.named_modules():
+ if not hasattr(module, "_diffusers_hook"):
+ continue
- max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
- self.assertLess(max_diff, expected_max_difference)
+ if name == "":
+ # Root denoiser module
+ state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state
+ self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.")
+ self.assertTrue(state.low_frequency_delta is None, "Low frequency delta should be reset to None.")
+ self.assertTrue(state.high_frequency_delta is None, "High frequency delta should be reset to None.")
+ else:
+ # Internal blocks
+ state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state
+ self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.")
+ self.assertTrue(state.batch_size is None, "Batch size should be reset to None.")
+ self.assertTrue(state.cache is None, "Cache should be reset to None.")
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
index bca4fdbfae64..5d0f8299f68e 100644
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
+++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
@@ -23,10 +23,11 @@
from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoSDPipeline, UNet3DConditionModel
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
load_numpy,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
skip_mps,
slow,
torch_device,
@@ -173,22 +174,30 @@ def test_inference_batch_single_identical(self):
def test_num_images_per_prompt(self):
pass
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "num_images_per_prompt": 1,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@slow
@skip_mps
-@require_torch_gpu
+@require_torch_accelerator
class TextToVideoSDPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_two_step_model(self):
expected_video = load_numpy(
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
index 8bef0cede154..db24767b60fc 100644
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
+++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
@@ -23,8 +23,14 @@
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoZeroSDXLPipeline, UNet2DConditionModel
-from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ nightly,
+ require_accelerate_version_greater,
+ require_accelerator,
+ require_torch_gpu,
+ torch_device,
+)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineTesterMixin
@@ -213,7 +219,8 @@ def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
self.assertLess(max_diff, expected_max_difference)
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_float16_inference(self, expected_max_diff=5e-2):
components = self.get_dummy_components()
for name, module in components.items():
@@ -255,10 +262,8 @@ def test_inference_batch_consistent(self):
def test_inference_batch_single_identical(self):
pass
- @unittest.skipIf(
- torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
- reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
- )
+ @require_accelerator
+ @require_accelerate_version_greater("0.17.0")
def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -268,7 +273,7 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
inputs = self.get_dummy_inputs(self.generator_device)
output_without_offload = pipe(**inputs)[0]
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_dummy_inputs(self.generator_device)
output_with_offload = pipe(**inputs)[0]
@@ -279,7 +284,8 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
def test_pipeline_call_signature(self):
pass
- @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
@@ -331,7 +337,7 @@ def test_save_load_optional_components(self):
def test_sequential_cpu_offload_forward_pass(self):
pass
- @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ @require_accelerator
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -344,12 +350,12 @@ def test_to_device(self):
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
self.assertTrue(np.isnan(output_cpu).sum() == 0)
- pipe.to("cuda")
+ pipe.to(torch_device)
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
- self.assertTrue(all(device == "cuda" for device in model_devices))
+ self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
- self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+ output_device = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
+ self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
@unittest.skip(
reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
diff --git a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
index 34ccb09e2204..f44a8aa33c5a 100644
--- a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
+++ b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
@@ -197,6 +197,14 @@ def test_inference_batch_single_identical(self):
def test_num_images_per_prompt(self):
pass
+ def test_encode_prompt_works_in_isolation(self):
+ extra_required_param_value_dict = {
+ "device": torch.device(torch_device).type,
+ "num_images_per_prompt": 1,
+ "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+ }
+ return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
@nightly
@skip_mps
diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py
index 07590c9db458..26a1bead0138 100644
--- a/tests/pipelines/unclip/test_unclip.py
+++ b/tests/pipelines/unclip/test_unclip.py
@@ -303,6 +303,7 @@ class DummyScheduler:
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
shape = (batch_size, decoder.config.in_channels, decoder.config.sample_size, decoder.config.sample_size)
+ generator = torch.Generator(device=device).manual_seed(0)
decoder_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py
index dfc3acc0c0f2..e402629fe1b9 100644
--- a/tests/pipelines/unclip/test_unclip_image_variation.py
+++ b/tests/pipelines/unclip/test_unclip_image_variation.py
@@ -66,6 +66,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
"super_res_num_inference_steps",
]
test_xformers_attention = False
+ supports_dduf = False
@property
def text_embedder_hidden_size(self):
@@ -406,6 +407,7 @@ class DummyScheduler:
pipe.super_res_first.config.sample_size,
pipe.super_res_first.config.sample_size,
)
+ generator = torch.Generator(device=device).manual_seed(0)
super_res_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py
index 2e0ba1cfb8eb..292978eb6eee 100644
--- a/tests/pipelines/unidiffuser/test_unidiffuser.py
+++ b/tests/pipelines/unidiffuser/test_unidiffuser.py
@@ -27,6 +27,7 @@
load_image,
nightly,
require_torch_2,
+ require_torch_accelerator,
require_torch_gpu,
run_test_in_subprocess,
torch_device,
@@ -86,6 +87,8 @@ class UniDiffuserPipelineFastTests(
# vae_latents, not latents, is the argument that corresponds to VAE latent inputs
image_latents_params = frozenset(["vae_latents"])
+ supports_dduf = False
+
def get_dummy_components(self):
unet = UniDiffuserModel.from_pretrained(
"hf-internal-testing/unidiffuser-diffusers-test",
@@ -499,20 +502,19 @@ def test_unidiffuser_img2text_multiple_prompts_with_latents(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=2e-4)
- @require_torch_gpu
- def test_unidiffuser_default_joint_v1_cuda_fp16(self):
- device = "cuda"
+ @require_torch_accelerator
+ def test_unidiffuser_default_joint_v1_fp16(self):
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
"hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
+ unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'joint'
unidiffuser_pipe.set_joint_mode()
assert unidiffuser_pipe.mode == "joint"
- inputs = self.get_dummy_inputs_with_latents(device)
+ inputs = self.get_dummy_inputs_with_latents(torch_device)
# Delete prompt and image for joint inference.
del inputs["prompt"]
del inputs["image"]
@@ -529,20 +531,19 @@ def test_unidiffuser_default_joint_v1_cuda_fp16(self):
expected_text_prefix = '" This This'
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
- @require_torch_gpu
- def test_unidiffuser_default_text2img_v1_cuda_fp16(self):
- device = "cuda"
+ @require_torch_accelerator
+ def test_unidiffuser_default_text2img_v1_fp16(self):
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
"hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
+ unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'text2img'
unidiffuser_pipe.set_text_to_image_mode()
assert unidiffuser_pipe.mode == "text2img"
- inputs = self.get_dummy_inputs_with_latents(device)
+ inputs = self.get_dummy_inputs_with_latents(torch_device)
# Delete prompt and image for joint inference.
del inputs["image"]
inputs["data_type"] = 1
@@ -554,20 +555,19 @@ def test_unidiffuser_default_text2img_v1_cuda_fp16(self):
expected_img_slice = np.array([0.5054, 0.5498, 0.5854, 0.3052, 0.4458, 0.6489, 0.5122, 0.4810, 0.6138])
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
- @require_torch_gpu
- def test_unidiffuser_default_img2text_v1_cuda_fp16(self):
- device = "cuda"
+ @require_torch_accelerator
+ def test_unidiffuser_default_img2text_v1_fp16(self):
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
"hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
+ unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'img2text'
unidiffuser_pipe.set_image_to_text_mode()
assert unidiffuser_pipe.mode == "img2text"
- inputs = self.get_dummy_inputs_with_latents(device)
+ inputs = self.get_dummy_inputs_with_latents(torch_device)
# Delete prompt and image for joint inference.
del inputs["prompt"]
inputs["data_type"] = 1
@@ -576,6 +576,12 @@ def test_unidiffuser_default_img2text_v1_cuda_fp16(self):
expected_text_prefix = '" This This'
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
+ @unittest.skip(
+ "Test not supported becauseit has a bunch of direct configs at init and also, this pipeline isn't used that much now."
+ )
+ def test_encode_prompt_works_in_isolation():
+ pass
+
@nightly
@require_torch_gpu
diff --git a/tests/pipelines/wan/__init__.py b/tests/pipelines/wan/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py
new file mode 100644
index 000000000000..a162e6841d2d
--- /dev/null
+++ b/tests/pipelines/wan/test_wan.py
@@ -0,0 +1,156 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+
+@slow
+@require_torch_accelerator
+class WanPipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_Wanx(self):
+ pass
diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py
new file mode 100644
index 000000000000..53fa37dfae99
--- /dev/null
+++ b/tests/pipelines/wan/test_wan_image_to_video.py
@@ -0,0 +1,162 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import (
+ AutoTokenizer,
+ CLIPImageProcessor,
+ CLIPVisionConfig,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+)
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanImageToVideoPipeline, WanTransformer3DModel
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ )
+
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=4,
+ projection_dim=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ image_size=32,
+ intermediate_size=16,
+ patch_size=1,
+ )
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+ torch.manual_seed(0)
+ image_processor = CLIPImageProcessor(crop_size=32, size=32)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ "image_processor": image_processor,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+ def test_inference_batch_single_identical(self):
+ pass
diff --git a/tests/pipelines/wan/test_wan_video_to_video.py b/tests/pipelines/wan/test_wan_video_to_video.py
new file mode 100644
index 000000000000..11c748424a30
--- /dev/null
+++ b/tests/pipelines/wan/test_wan_video_to_video.py
@@ -0,0 +1,146 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanTransformer3DModel, WanVideoToVideoPipeline
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanVideoToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = frozenset(["video", "prompt", "negative_prompt"])
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=3.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ video = [Image.new("RGB", (16, 16))] * 17
+ inputs = {
+ "video": video,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 4,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (17, 3, 16, 16))
+ expected_video = torch.randn(17, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip(
+ "WanVideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors"
+ )
+ def test_float16_inference(self):
+ pass
+
+ @unittest.skip(
+ "WanVideoToVideoPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors"
+ )
+ def test_save_load_float16(self):
+ pass
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
index 0caed159100a..084d62a8c613 100644
--- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
@@ -21,7 +21,7 @@
from diffusers import DDPMWuerstchenScheduler, WuerstchenCombinedPipeline
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
@@ -198,7 +198,7 @@ def test_wuerstchen(self):
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- @require_torch_gpu
+ @require_torch_accelerator
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
@@ -207,12 +207,12 @@ def test_offloads(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload()
+ sd_pipe.enable_sequential_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload()
+ sd_pipe.enable_model_cpu_offload(device=torch_device)
pipes.append(sd_pipe)
image_slices = []
@@ -232,8 +232,10 @@ def test_inference_batch_single_identical(self):
def test_float16_inference(self):
super().test_float16_inference()
+ @unittest.skip(reason="Test not supported.")
def test_callback_inputs(self):
pass
+ @unittest.skip(reason="Test not supported.")
def test_callback_cfg(self):
pass
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
index 467550138790..97d1a1cc3830 100644
--- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
@@ -186,3 +186,7 @@ def test_attention_slicing_forward_pass(self):
@unittest.skip(reason="bf16 not supported and requires CUDA")
def test_float16_inference(self):
super().test_float16_inference()
+
+ @unittest.skip("Test not supoorted.")
+ def test_encode_prompt_works_in_isolation(self):
+ super().test_encode_prompt_works_in_isolation()
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
index 460004da6f04..4bc086e7f65b 100644
--- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
@@ -267,3 +267,7 @@ def test_inference_with_prior_lora(self):
lora_image_embed = output_lora.image_embeddings
self.assertTrue(image_embed.shape == lora_image_embed.shape)
+
+ @unittest.skip("Test not supported as dtype cannot be inferred without the text encoder otherwise.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/quantization/__init__.py b/tests/quantization/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 6c1b24e31e2a..29a3e212c48d 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -13,15 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
+import os
import tempfile
import unittest
import numpy as np
+import pytest
+import safetensors.torch
+from huggingface_hub import hf_hub_download
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
-from diffusers.utils import logging
+from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
+ backend_empty_cache,
is_bitsandbytes_available,
is_torch_available,
is_transformers_available,
@@ -29,8 +34,9 @@
numpy_cosine_similarity_distance,
require_accelerate,
require_bitsandbytes_version_greater,
+ require_peft_backend,
require_torch,
- require_torch_gpu,
+ require_torch_accelerator,
require_transformers_version_greater,
slow,
torch_device,
@@ -45,33 +51,13 @@ def get_some_linear_layer(model):
if is_transformers_available():
+ from transformers import BitsAndBytesConfig as BnbConfig
from transformers import T5EncoderModel
if is_torch_available():
import torch
- import torch.nn as nn
- class LoRALayer(nn.Module):
- """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
-
- Taken from
- https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
- """
-
- def __init__(self, module: nn.Module, rank: int):
- super().__init__()
- self.module = module
- self.adapter = nn.Sequential(
- nn.Linear(module.in_features, rank, bias=False),
- nn.Linear(rank, module.out_features, bias=False),
- )
- small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
- nn.init.normal_(self.adapter[0].weight, std=small_std)
- nn.init.zeros_(self.adapter[1].weight)
- self.adapter.to(module.weight.device)
-
- def forward(self, input, *args, **kwargs):
- return self.module(input, *args, **kwargs) + self.adapter(input)
+ from ..utils import LoRALayer, get_memory_consumption_stat
if is_bitsandbytes_available():
@@ -81,7 +67,7 @@ def forward(self, input, *args, **kwargs):
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@require_torch
-@require_torch_gpu
+@require_torch_accelerator
@slow
class Base4bitTests(unittest.TestCase):
# We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
@@ -91,19 +77,24 @@ class Base4bitTests(unittest.TestCase):
# This was obtained on audace so the number might slightly change
expected_rel_difference = 3.69
+ expected_memory_saving_ratio = 0.8
+
prompt = "a beautiful sunset amidst the mountains."
num_inference_steps = 10
seed = 0
def get_dummy_inputs(self):
prompt_embeds = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
+ torch_device,
)
pooled_prompt_embeds = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
+ torch_device,
)
latent_model_input = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
+ torch_device,
)
input_dict_for_transformer = {
@@ -118,6 +109,9 @@ def get_dummy_inputs(self):
class BnB4BitBasicTests(Base4bitTests):
def setUp(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
# Models
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", torch_dtype=torch.float16
@@ -128,15 +122,17 @@ def setUp(self):
bnb_4bit_compute_dtype=torch.float16,
)
self.model_4bit = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=nf4_config
+ self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
def tearDown(self):
- del self.model_fp16
- del self.model_4bit
+ if hasattr(self, "model_fp16"):
+ del self.model_fp16
+ if hasattr(self, "model_4bit"):
+ del self.model_4bit
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_quantization_num_parameters(self):
r"""
@@ -172,6 +168,32 @@ def test_memory_footprint(self):
linear = get_some_linear_layer(self.model_4bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
+ def test_model_memory_usage(self):
+ # Delete to not let anything interfere.
+ del self.model_4bit, self.model_fp16
+
+ # Re-instantiate.
+ inputs = self.get_dummy_inputs()
+ inputs = {
+ k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
+ }
+ model_fp16 = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", torch_dtype=torch.float16
+ ).to(torch_device)
+ unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
+ del model_fp16
+
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ model_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16
+ )
+ quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs)
+ assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
+
def test_original_dtype(self):
r"""
A simple test to check if the model succesfully stores the original dtype
@@ -194,7 +216,7 @@ def test_keep_modules_in_fp32(self):
bnb_4bit_compute_dtype=torch.float16,
)
model = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=nf4_config
+ self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
for name, module in model.named_modules():
@@ -206,7 +228,7 @@ def test_keep_modules_in_fp32(self):
self.assertTrue(module.weight.dtype == torch.uint8)
# test if inference works.
- with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
+ with torch.no_grad() and torch.amp.autocast(torch_device, dtype=torch.float16):
input_dict_for_transformer = self.get_dummy_inputs()
model_inputs = {
k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
@@ -232,7 +254,7 @@ def test_linear_are_4bit(self):
def test_config_from_pretrained(self):
transformer_4bit = FluxTransformer2DModel.from_pretrained(
- "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
+ "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
linear = get_some_linear_layer(transformer_4bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
@@ -248,9 +270,9 @@ def test_device_assignment(self):
self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
# Move back to CUDA device
- for device in [0, "cuda", "cuda:0", "call()"]:
+ for device in [0, f"{torch_device}", f"{torch_device}:0", "call()"]:
if device == "call()":
- self.model_4bit.cuda(0)
+ self.model_4bit.to(f"{torch_device}:0")
else:
self.model_4bit.to(device)
self.assertEqual(self.model_4bit.device, torch.device(0))
@@ -268,7 +290,7 @@ def test_device_and_dtype_assignment(self):
with self.assertRaises(ValueError):
# Tries with a `device` and `dtype`
- self.model_4bit.to(device="cuda:0", dtype=torch.float16)
+ self.model_4bit.to(device=f"{torch_device}:0", dtype=torch.float16)
with self.assertRaises(ValueError):
# Tries with a cast
@@ -279,7 +301,7 @@ def test_device_and_dtype_assignment(self):
self.model_4bit.half()
# This should work
- self.model_4bit.to("cuda")
+ self.model_4bit.to(torch_device)
# Test if we did not break anything
self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
@@ -303,7 +325,7 @@ def test_device_and_dtype_assignment(self):
_ = self.model_fp16.float()
# Check that this does not throw an error
- _ = self.model_fp16.cuda()
+ _ = self.model_fp16.to(torch_device)
def test_bnb_4bit_wrong_config(self):
r"""
@@ -312,16 +334,49 @@ def test_bnb_4bit_wrong_config(self):
with self.assertRaises(ValueError):
_ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add")
+ def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
+ r"""
+ Test if loading with an incorrect state dict raises an error.
+ """
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ nf4_config = BitsAndBytesConfig(load_in_4bit=True)
+ model_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
+ )
+ model_4bit.save_pretrained(tmpdirname)
+ del model_4bit
+
+ with self.assertRaises(ValueError) as err_context:
+ state_dict = safetensors.torch.load_file(
+ os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
+ )
+
+ # corrupt the state dict
+ key_to_target = "context_embedder.weight" # can be other keys too.
+ compatible_param = state_dict[key_to_target]
+ corrupted_param = torch.randn(compatible_param.shape[0] - 1, 1)
+ state_dict[key_to_target] = bnb.nn.Params4bit(corrupted_param, requires_grad=False)
+ safetensors.torch.save_file(
+ state_dict, os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
+ )
+
+ _ = SD3Transformer2DModel.from_pretrained(tmpdirname)
+
+ assert key_to_target in str(err_context.exception)
+
class BnB4BitTrainingTests(Base4bitTests):
def setUp(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
self.model_4bit = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=nf4_config
+ self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
def test_training(self):
@@ -347,7 +402,7 @@ def test_training(self):
model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
# Step 4: Check if the gradient is not None
- with torch.amp.autocast("cuda", dtype=torch.float16):
+ with torch.amp.autocast(torch_device, dtype=torch.float16):
out = self.model_4bit(**model_inputs)[0]
out.norm().backward()
@@ -360,13 +415,16 @@ def test_training(self):
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitTests(Base4bitTests):
def setUp(self) -> None:
+ gc.collect()
+ backend_empty_cache(torch_device)
+
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=nf4_config
+ self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
self.pipeline_4bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_4bit, torch_dtype=torch.float16
@@ -377,7 +435,7 @@ def tearDown(self):
del self.pipeline_4bit
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_quality(self):
output = self.pipeline_4bit(
@@ -391,7 +449,6 @@ def test_quality(self):
expected_slice = np.array([0.1123, 0.1296, 0.1609, 0.1042, 0.1230, 0.1274, 0.0928, 0.1165, 0.1216])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
- print(f"{max_diff=}")
self.assertTrue(max_diff < 1e-2)
def test_generate_quality_dequantize(self):
@@ -429,7 +486,7 @@ def test_moving_to_cpu_throws_warning(self):
bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=nf4_config
+ self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device
)
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
@@ -443,12 +500,145 @@ def test_moving_to_cpu_throws_warning(self):
assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out
+ @pytest.mark.xfail(
+ condition=is_accelerate_version("<=", "1.1.1"),
+ reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
+ strict=True,
+ )
+ def test_pipeline_device_placement_works_with_nf4(self):
+ transformer_nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ transformer_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name,
+ subfolder="transformer",
+ quantization_config=transformer_nf4_config,
+ torch_dtype=torch.float16,
+ device_map=torch_device,
+ )
+ text_encoder_3_nf4_config = BnbConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ text_encoder_3_4bit = T5EncoderModel.from_pretrained(
+ self.model_name,
+ subfolder="text_encoder_3",
+ quantization_config=text_encoder_3_nf4_config,
+ torch_dtype=torch.float16,
+ device_map=torch_device,
+ )
+ # CUDA device placement works.
+ pipeline_4bit = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ transformer=transformer_4bit,
+ text_encoder_3=text_encoder_3_4bit,
+ torch_dtype=torch.float16,
+ ).to(torch_device)
+
+ # Check if inference works.
+ _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
+
+ del pipeline_4bit
+
+ def test_device_map(self):
+ """
+ Test if the quantized model is working properly with "auto".
+ cpu/disk offloading as well doesn't work with bnb.
+ """
+
+ def get_dummy_tensor_inputs(device=None, seed: int = 0):
+ batch_size = 1
+ num_latent_channels = 4
+ num_image_channels = 3
+ height = width = 4
+ sequence_length = 48
+ embedding_dim = 32
+
+ torch.manual_seed(seed)
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
+ device, dtype=torch.bfloat16
+ )
+ torch.manual_seed(seed)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
+ device, dtype=torch.bfloat16
+ )
+
+ torch.manual_seed(seed)
+ pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
+
+ torch.manual_seed(seed)
+ text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+ torch.manual_seed(seed)
+ image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+ timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_prompt_embeds,
+ "txt_ids": text_ids,
+ "img_ids": image_ids,
+ "timestep": timestep,
+ }
+
+ inputs = get_dummy_tensor_inputs(torch_device)
+ expected_slice = np.array(
+ [0.47070312, 0.00390625, -0.03662109, -0.19628906, -0.53125, 0.5234375, -0.17089844, -0.59375, 0.578125]
+ )
+
+ # non sharded
+ quantization_config = BitsAndBytesConfig(
+ load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
+ )
+ quantized_model = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-flux-pipe",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+ )
+
+ weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+ self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
+
+ output = quantized_model(**inputs)[0]
+ output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+ self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+
+ # sharded
+
+ quantization_config = BitsAndBytesConfig(
+ load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
+ )
+ quantized_model = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-flux-sharded",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+ )
+
+ weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+ self.assertTrue(isinstance(weight, bnb.nn.modules.Params4bit))
+
+ output = quantized_model(**inputs)[0]
+ output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+ self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitFluxTests(Base4bitTests):
def setUp(self) -> None:
- # TODO: Copy sayakpaul/flux.1-dev-nf4-pkg to testing repo.
- model_id = "sayakpaul/flux.1-dev-nf4-pkg"
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
self.pipeline_4bit = DiffusionPipeline.from_pretrained(
@@ -483,12 +673,34 @@ def test_quality(self):
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)
+ @require_peft_backend
+ def test_lora_loading(self):
+ self.pipeline_4bit.load_lora_weights(
+ hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
+ )
+ self.pipeline_4bit.set_adapters("hyper-sd", adapter_weights=0.125)
+
+ output = self.pipeline_4bit(
+ prompt=self.prompt,
+ height=256,
+ width=256,
+ max_sequence_length=64,
+ output_type="np",
+ num_inference_steps=8,
+ generator=torch.Generator().manual_seed(42),
+ ).images
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-3)
+
@slow
class BaseBnb4BitSerializationTests(Base4bitTests):
def tearDown(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
r"""
@@ -503,7 +715,10 @@ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serializa
bnb_4bit_compute_dtype=torch.bfloat16,
)
model_0 = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=self.quantization_config
+ self.model_name,
+ subfolder="transformer",
+ quantization_config=self.quantization_config,
+ device_map=torch_device,
)
self.assertTrue("_pre_quantization_dtype" in model_0.config)
with tempfile.TemporaryDirectory() as tmpdirname:
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 2e4aec39b427..8809bac25f58 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -17,10 +17,21 @@
import unittest
import numpy as np
-
-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+import pytest
+from huggingface_hub import hf_hub_download
+
+from diffusers import (
+ BitsAndBytesConfig,
+ DiffusionPipeline,
+ FluxTransformer2DModel,
+ SanaTransformer2DModel,
+ SD3Transformer2DModel,
+ logging,
+)
+from diffusers.utils import is_accelerate_version
from diffusers.utils.testing_utils import (
CaptureLogger,
+ backend_empty_cache,
is_bitsandbytes_available,
is_torch_available,
is_transformers_available,
@@ -28,8 +39,9 @@
numpy_cosine_similarity_distance,
require_accelerate,
require_bitsandbytes_version_greater,
+ require_peft_version_greater,
require_torch,
- require_torch_gpu,
+ require_torch_accelerator,
require_transformers_version_greater,
slow,
torch_device,
@@ -44,33 +56,13 @@ def get_some_linear_layer(model):
if is_transformers_available():
+ from transformers import BitsAndBytesConfig as BnbConfig
from transformers import T5EncoderModel
if is_torch_available():
import torch
- import torch.nn as nn
-
- class LoRALayer(nn.Module):
- """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
- Taken from
- https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
- """
-
- def __init__(self, module: nn.Module, rank: int):
- super().__init__()
- self.module = module
- self.adapter = nn.Sequential(
- nn.Linear(module.in_features, rank, bias=False),
- nn.Linear(rank, module.out_features, bias=False),
- )
- small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
- nn.init.normal_(self.adapter[0].weight, std=small_std)
- nn.init.zeros_(self.adapter[1].weight)
- self.adapter.to(module.weight.device)
-
- def forward(self, input, *args, **kwargs):
- return self.module(input, *args, **kwargs) + self.adapter(input)
+ from ..utils import LoRALayer, get_memory_consumption_stat
if is_bitsandbytes_available():
@@ -80,7 +72,7 @@ def forward(self, input, *args, **kwargs):
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@require_torch
-@require_torch_gpu
+@require_torch_accelerator
@slow
class Base8bitTests(unittest.TestCase):
# We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected)
@@ -90,19 +82,24 @@ class Base8bitTests(unittest.TestCase):
# This was obtained on audace so the number might slightly change
expected_rel_difference = 1.94
+ expected_memory_saving_ratio = 0.7
+
prompt = "a beautiful sunset amidst the mountains."
num_inference_steps = 10
seed = 0
def get_dummy_inputs(self):
prompt_embeds = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
+ map_location="cpu",
)
pooled_prompt_embeds = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
+ map_location="cpu",
)
latent_model_input = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
+ map_location="cpu",
)
input_dict_for_transformer = {
@@ -117,21 +114,26 @@ def get_dummy_inputs(self):
class BnB8bitBasicTests(Base8bitTests):
def setUp(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
# Models
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", torch_dtype=torch.float16
)
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
def tearDown(self):
- del self.model_fp16
- del self.model_8bit
+ if hasattr(self, "model_fp16"):
+ del self.model_fp16
+ if hasattr(self, "model_8bit"):
+ del self.model_8bit
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_quantization_num_parameters(self):
r"""
@@ -167,6 +169,28 @@ def test_memory_footprint(self):
linear = get_some_linear_layer(self.model_8bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
+ def test_model_memory_usage(self):
+ # Delete to not let anything interfere.
+ del self.model_8bit, self.model_fp16
+
+ # Re-instantiate.
+ inputs = self.get_dummy_inputs()
+ inputs = {
+ k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
+ }
+ model_fp16 = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", torch_dtype=torch.float16
+ ).to(torch_device)
+ unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
+ del model_fp16
+
+ config = BitsAndBytesConfig(load_in_8bit=True)
+ model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=config, torch_dtype=torch.float16
+ )
+ quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs)
+ assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio
+
def test_original_dtype(self):
r"""
A simple test to check if the model succesfully stores the original dtype
@@ -185,7 +209,7 @@ def test_keep_modules_in_fp32(self):
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
for name, module in model.named_modules():
@@ -197,7 +221,7 @@ def test_keep_modules_in_fp32(self):
self.assertTrue(module.weight.dtype == torch.int8)
# test if inference works.
- with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
+ with torch.no_grad() and torch.autocast(model.device.type, dtype=torch.float16):
input_dict_for_transformer = self.get_dummy_inputs()
model_inputs = {
k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
@@ -227,18 +251,18 @@ def test_llm_skip(self):
"""
config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"])
model_8bit = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=config
+ self.model_name, subfolder="transformer", quantization_config=config, device_map=torch_device
)
linear = get_some_linear_layer(model_8bit)
self.assertTrue(linear.weight.dtype == torch.int8)
self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt))
- self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear))
+ self.assertTrue(isinstance(model_8bit.proj_out, torch.nn.Linear))
self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8)
def test_config_from_pretrained(self):
transformer_8bit = FluxTransformer2DModel.from_pretrained(
- "sayakpaul/flux.1-dev-int8-pkg", subfolder="transformer"
+ "hf-internal-testing/flux.1-dev-int8-pkg", subfolder="transformer"
)
linear = get_some_linear_layer(transformer_8bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
@@ -259,7 +283,7 @@ def test_device_and_dtype_assignment(self):
with self.assertRaises(ValueError):
# Tries with a `device`
- self.model_8bit.to(torch.device("cuda:0"))
+ self.model_8bit.to(torch.device(f"{torch_device}:0"))
with self.assertRaises(ValueError):
# Tries with a `device`
@@ -291,14 +315,45 @@ def test_device_and_dtype_assignment(self):
_ = self.model_fp16.float()
# Check that this does not throw an error
- _ = self.model_fp16.cuda()
+ _ = self.model_fp16.to(torch_device)
+
+
+class Bnb8bitDeviceTests(Base8bitTests):
+ def setUp(self) -> None:
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+ self.model_8bit = SanaTransformer2DModel.from_pretrained(
+ "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
+ subfolder="transformer",
+ quantization_config=mixed_int8_config,
+ device_map=torch_device,
+ )
+
+ def tearDown(self):
+ del self.model_8bit
+
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_buffers_device_assignment(self):
+ for buffer_name, buffer in self.model_8bit.named_buffers():
+ self.assertEqual(
+ buffer.device.type,
+ torch.device(torch_device).type,
+ f"Expected device {torch_device} for {buffer_name} got {buffer.device}.",
+ )
class BnB8bitTrainingTests(Base8bitTests):
def setUp(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
def test_training(self):
@@ -337,9 +392,12 @@ def test_training(self):
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitTests(Base8bitTests):
def setUp(self) -> None:
+ gc.collect()
+ backend_empty_cache(torch_device)
+
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config, device_map=torch_device
)
self.pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_8bit, torch_dtype=torch.float16
@@ -350,7 +408,7 @@ def tearDown(self):
del self.pipeline_8bit
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_quality(self):
output = self.pipeline_8bit(
@@ -360,14 +418,17 @@ def test_quality(self):
output_type="np",
).images
out_slice = output[0, -3:, -3:, -1].flatten()
- expected_slice = np.array([0.0149, 0.0322, 0.0073, 0.0134, 0.0332, 0.011, 0.002, 0.0232, 0.0193])
+ expected_slice = np.array([0.0674, 0.0623, 0.0364, 0.0632, 0.0671, 0.0430, 0.0317, 0.0493, 0.0583])
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-2)
def test_model_cpu_offload_raises_warning(self):
model_8bit = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+ self.model_name,
+ subfolder="transformer",
+ quantization_config=BitsAndBytesConfig(load_in_8bit=True),
+ device_map=torch_device,
)
pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name, transformer=model_8bit, torch_dtype=torch.float16
@@ -382,7 +443,10 @@ def test_model_cpu_offload_raises_warning(self):
def test_moving_to_cpu_throws_warning(self):
model_8bit = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+ self.model_name,
+ subfolder="transformer",
+ quantization_config=BitsAndBytesConfig(load_in_8bit=True),
+ device_map=torch_device,
)
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel(30)
@@ -423,12 +487,142 @@ def test_generate_quality_dequantize(self):
output_type="np",
).images
+ @pytest.mark.xfail(
+ condition=is_accelerate_version("<=", "1.1.1"),
+ reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
+ strict=True,
+ )
+ def test_pipeline_cuda_placement_works_with_mixed_int8(self):
+ transformer_8bit_config = BitsAndBytesConfig(load_in_8bit=True)
+ transformer_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name,
+ subfolder="transformer",
+ quantization_config=transformer_8bit_config,
+ torch_dtype=torch.float16,
+ device_map=torch_device,
+ )
+ text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
+ text_encoder_3_8bit = T5EncoderModel.from_pretrained(
+ self.model_name,
+ subfolder="text_encoder_3",
+ quantization_config=text_encoder_3_8bit_config,
+ torch_dtype=torch.float16,
+ device_map=torch_device,
+ )
+ # CUDA device placement works.
+ pipeline_8bit = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ transformer=transformer_8bit,
+ text_encoder_3=text_encoder_3_8bit,
+ torch_dtype=torch.float16,
+ ).to("cuda")
+
+ # Check if inference works.
+ _ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
+
+ del pipeline_8bit
+
+ def test_device_map(self):
+ """
+ Test if the quantized model is working properly with "auto"
+ pu/disk offloading doesn't work with bnb.
+ """
+
+ def get_dummy_tensor_inputs(device=None, seed: int = 0):
+ batch_size = 1
+ num_latent_channels = 4
+ num_image_channels = 3
+ height = width = 4
+ sequence_length = 48
+ embedding_dim = 32
+
+ torch.manual_seed(seed)
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(
+ device, dtype=torch.bfloat16
+ )
+
+ torch.manual_seed(seed)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
+ device, dtype=torch.bfloat16
+ )
+
+ torch.manual_seed(seed)
+ pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
+
+ torch.manual_seed(seed)
+ text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+ torch.manual_seed(seed)
+ image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+ timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_prompt_embeds,
+ "txt_ids": text_ids,
+ "img_ids": image_ids,
+ "timestep": timestep,
+ }
+
+ inputs = get_dummy_tensor_inputs(torch_device)
+ expected_slice = np.array(
+ [
+ 0.33789062,
+ -0.04736328,
+ -0.00256348,
+ -0.23144531,
+ -0.49804688,
+ 0.4375,
+ -0.15429688,
+ -0.65234375,
+ 0.44335938,
+ ]
+ )
+
+ # non sharded
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+ quantized_model = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-flux-pipe",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+ )
+
+ weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+ self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))
+
+ output = quantized_model(**inputs)[0]
+ output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+ self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+
+ # sharded
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+ quantized_model = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-flux-sharded",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+ )
+
+ weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+ self.assertTrue(isinstance(weight, bnb.nn.modules.Int8Params))
+ output = quantized_model(**inputs)[0]
+ output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+ self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):
def setUp(self) -> None:
- # TODO: Copy sayakpaul/flux.1-dev-int8-pkg to testing repo.
- model_id = "sayakpaul/flux.1-dev-int8-pkg"
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
self.pipeline_8bit = DiffusionPipeline.from_pretrained(
@@ -443,7 +637,7 @@ def tearDown(self):
del self.pipeline_8bit
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_quality(self):
# keep the resolution and max tokens to a lower number for faster execution.
@@ -462,22 +656,48 @@ def test_quality(self):
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)
+ @require_peft_version_greater("0.14.0")
+ def test_lora_loading(self):
+ self.pipeline_8bit.load_lora_weights(
+ hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
+ )
+ self.pipeline_8bit.set_adapters("hyper-sd", adapter_weights=0.125)
+
+ output = self.pipeline_8bit(
+ prompt=self.prompt,
+ height=256,
+ width=256,
+ max_sequence_length=64,
+ output_type="np",
+ num_inference_steps=8,
+ generator=torch.manual_seed(42),
+ ).images
+ out_slice = output[0, -3:, -3:, -1].flatten()
+
+ expected_slice = np.array([0.3916, 0.3916, 0.3887, 0.4243, 0.4155, 0.4233, 0.4570, 0.4531, 0.4248])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-3)
+
@slow
class BaseBnb8bitSerializationTests(Base8bitTests):
def setUp(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
)
self.model_0 = SD3Transformer2DModel.from_pretrained(
- self.model_name, subfolder="transformer", quantization_config=quantization_config
+ self.model_name, subfolder="transformer", quantization_config=quantization_config, device_map=torch_device
)
def tearDown(self):
del self.model_0
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_serialization(self):
r"""
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
new file mode 100644
index 000000000000..5e3875c7c9cb
--- /dev/null
+++ b/tests/quantization/gguf/test_gguf.py
@@ -0,0 +1,458 @@
+import gc
+import unittest
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from diffusers import (
+ AuraFlowPipeline,
+ AuraFlowTransformer2DModel,
+ FluxPipeline,
+ FluxTransformer2DModel,
+ GGUFQuantizationConfig,
+ SD3Transformer2DModel,
+ StableDiffusion3Pipeline,
+)
+from diffusers.utils.testing_utils import (
+ is_gguf_available,
+ nightly,
+ numpy_cosine_similarity_distance,
+ require_accelerate,
+ require_big_gpu_with_torch_cuda,
+ require_gguf_version_greater_or_equal,
+ torch_device,
+)
+
+
+if is_gguf_available():
+ from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
+
+
+@nightly
+@require_big_gpu_with_torch_cuda
+@require_accelerate
+@require_gguf_version_greater_or_equal("0.10.0")
+class GGUFSingleFileTesterMixin:
+ ckpt_path = None
+ model_cls = None
+ torch_dtype = torch.bfloat16
+ expected_memory_use_in_gb = 5
+
+ def test_gguf_parameters(self):
+ quant_storage_type = torch.uint8
+ quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+ model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
+
+ for param_name, param in model.named_parameters():
+ if isinstance(param, GGUFParameter):
+ assert hasattr(param, "quant_type")
+ assert param.dtype == quant_storage_type
+
+ def test_gguf_linear_layers(self):
+ quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+ model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
+
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
+ assert module.weight.dtype == torch.uint8
+ if module.bias is not None:
+ assert module.bias.dtype == self.torch_dtype
+
+ def test_gguf_memory_usage(self):
+ quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+
+ model = self.model_cls.from_single_file(
+ self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
+ )
+ model.to("cuda")
+ assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb
+ inputs = self.get_dummy_inputs()
+
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.empty_cache()
+ with torch.no_grad():
+ model(**inputs)
+ max_memory = torch.cuda.max_memory_allocated()
+ assert (max_memory / 1024**3) < self.expected_memory_use_in_gb
+
+ def test_keep_modules_in_fp32(self):
+ r"""
+ A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32.
+ Also ensures if inference works.
+ """
+ _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
+ self.model_cls._keep_in_fp32_modules = ["proj_out"]
+
+ quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+ model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
+
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name in model._keep_in_fp32_modules:
+ assert module.weight.dtype == torch.float32
+ self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
+
+ def test_dtype_assignment(self):
+ quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+ model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
+
+ with self.assertRaises(ValueError):
+ # Tries with a `dtype`
+ model.to(torch.float16)
+
+ with self.assertRaises(ValueError):
+ # Tries with a `device` and `dtype`
+ model.to(device="cuda:0", dtype=torch.float16)
+
+ with self.assertRaises(ValueError):
+ # Tries with a cast
+ model.float()
+
+ with self.assertRaises(ValueError):
+ # Tries with a cast
+ model.half()
+
+ # This should work
+ model.to("cuda")
+
+ def test_dequantize_model(self):
+ quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+ model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
+ model.dequantize()
+
+ def _check_for_gguf_linear(model):
+ has_children = list(model.children())
+ if not has_children:
+ return
+
+ for name, module in model.named_children():
+ if isinstance(module, nn.Linear):
+ assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
+ assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"
+
+ for name, module in model.named_children():
+ _check_for_gguf_linear(module)
+
+
+class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = FluxTransformer2DModel
+ expected_memory_use_in_gb = 5
+
+ def setUp(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "pooled_projections": torch.randn(
+ (1, 768),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
+ }
+
+ def test_pipeline_inference(self):
+ quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+ transformer = self.model_cls.from_single_file(
+ self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
+ )
+ pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype
+ )
+ pipe.enable_model_cpu_offload()
+
+ prompt = "a cat holding a sign that says hello"
+ output = pipe(
+ prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
+ ).images[0]
+ output_slice = output[:3, :3, :].flatten()
+ expected_slice = np.array(
+ [
+ 0.47265625,
+ 0.43359375,
+ 0.359375,
+ 0.47070312,
+ 0.421875,
+ 0.34375,
+ 0.46875,
+ 0.421875,
+ 0.34765625,
+ 0.46484375,
+ 0.421875,
+ 0.34179688,
+ 0.47070312,
+ 0.42578125,
+ 0.34570312,
+ 0.46875,
+ 0.42578125,
+ 0.3515625,
+ 0.45507812,
+ 0.4140625,
+ 0.33984375,
+ 0.4609375,
+ 0.41796875,
+ 0.34375,
+ 0.45898438,
+ 0.41796875,
+ 0.34375,
+ ]
+ )
+ max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
+ assert max_diff < 1e-4
+
+
+class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = SD3Transformer2DModel
+ expected_memory_use_in_gb = 5
+
+ def setUp(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "pooled_projections": torch.randn(
+ (1, 2048),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+ def test_pipeline_inference(self):
+ quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+ transformer = self.model_cls.from_single_file(
+ self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
+ )
+ pipe = StableDiffusion3Pipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3.5-large", transformer=transformer, torch_dtype=self.torch_dtype
+ )
+ pipe.enable_model_cpu_offload()
+
+ prompt = "a cat holding a sign that says hello"
+ output = pipe(
+ prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
+ ).images[0]
+ output_slice = output[:3, :3, :].flatten()
+ expected_slice = np.array(
+ [
+ 0.17578125,
+ 0.27539062,
+ 0.27734375,
+ 0.11914062,
+ 0.26953125,
+ 0.25390625,
+ 0.109375,
+ 0.25390625,
+ 0.25,
+ 0.15039062,
+ 0.26171875,
+ 0.28515625,
+ 0.13671875,
+ 0.27734375,
+ 0.28515625,
+ 0.12109375,
+ 0.26757812,
+ 0.265625,
+ 0.16210938,
+ 0.29882812,
+ 0.28515625,
+ 0.15625,
+ 0.30664062,
+ 0.27734375,
+ 0.14648438,
+ 0.29296875,
+ 0.26953125,
+ ]
+ )
+ max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
+ assert max_diff < 1e-4
+
+
+class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = SD3Transformer2DModel
+ expected_memory_use_in_gb = 2
+
+ def setUp(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "pooled_projections": torch.randn(
+ (1, 2048),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+ def test_pipeline_inference(self):
+ quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+ transformer = self.model_cls.from_single_file(
+ self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
+ )
+ pipe = StableDiffusion3Pipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3.5-medium", transformer=transformer, torch_dtype=self.torch_dtype
+ )
+ pipe.enable_model_cpu_offload()
+
+ prompt = "a cat holding a sign that says hello"
+ output = pipe(
+ prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
+ ).images[0]
+ output_slice = output[:3, :3, :].flatten()
+ expected_slice = np.array(
+ [
+ 0.625,
+ 0.6171875,
+ 0.609375,
+ 0.65625,
+ 0.65234375,
+ 0.640625,
+ 0.6484375,
+ 0.640625,
+ 0.625,
+ 0.6484375,
+ 0.63671875,
+ 0.6484375,
+ 0.66796875,
+ 0.65625,
+ 0.65234375,
+ 0.6640625,
+ 0.6484375,
+ 0.6328125,
+ 0.6640625,
+ 0.6484375,
+ 0.640625,
+ 0.67578125,
+ 0.66015625,
+ 0.62109375,
+ 0.671875,
+ 0.65625,
+ 0.62109375,
+ ]
+ )
+ max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
+ assert max_diff < 1e-4
+
+
+class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = AuraFlowTransformer2DModel
+ expected_memory_use_in_gb = 4
+
+ def setUp(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 4, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 2048),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+ def test_pipeline_inference(self):
+ quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+ transformer = self.model_cls.from_single_file(
+ self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
+ )
+ pipe = AuraFlowPipeline.from_pretrained(
+ "fal/AuraFlow-v0.3", transformer=transformer, torch_dtype=self.torch_dtype
+ )
+ pipe.enable_model_cpu_offload()
+
+ prompt = "a pony holding a sign that says hello"
+ output = pipe(
+ prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
+ ).images[0]
+ output_slice = output[:3, :3, :].flatten()
+ expected_slice = np.array(
+ [
+ 0.46484375,
+ 0.546875,
+ 0.64453125,
+ 0.48242188,
+ 0.53515625,
+ 0.59765625,
+ 0.47070312,
+ 0.5078125,
+ 0.5703125,
+ 0.42773438,
+ 0.50390625,
+ 0.5703125,
+ 0.47070312,
+ 0.515625,
+ 0.57421875,
+ 0.45898438,
+ 0.48632812,
+ 0.53515625,
+ 0.4453125,
+ 0.5078125,
+ 0.56640625,
+ 0.47851562,
+ 0.5234375,
+ 0.57421875,
+ 0.48632812,
+ 0.5234375,
+ 0.56640625,
+ ]
+ )
+ max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
+ assert max_diff < 1e-4
diff --git a/tests/quantization/quanto/__init__.py b/tests/quantization/quanto/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/quantization/quanto/test_quanto.py b/tests/quantization/quanto/test_quanto.py
new file mode 100644
index 000000000000..9eb6958d2183
--- /dev/null
+++ b/tests/quantization/quanto/test_quanto.py
@@ -0,0 +1,328 @@
+import gc
+import tempfile
+import unittest
+
+from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
+from diffusers.models.attention_processor import Attention
+from diffusers.utils import is_optimum_quanto_available, is_torch_available
+from diffusers.utils.testing_utils import (
+ nightly,
+ numpy_cosine_similarity_distance,
+ require_accelerate,
+ require_big_gpu_with_torch_cuda,
+ require_torch_cuda_compatibility,
+ torch_device,
+)
+
+
+if is_optimum_quanto_available():
+ from optimum.quanto import QLinear
+
+if is_torch_available():
+ import torch
+
+ from ..utils import LoRALayer, get_memory_consumption_stat
+
+
+@nightly
+@require_big_gpu_with_torch_cuda
+@require_accelerate
+class QuantoBaseTesterMixin:
+ model_id = None
+ pipeline_model_id = None
+ model_cls = None
+ torch_dtype = torch.bfloat16
+ # the expected reduction in peak memory used compared to an unquantized model expressed as a percentage
+ expected_memory_reduction = 0.0
+ keep_in_fp32_module = ""
+ modules_to_not_convert = ""
+ _test_torch_compile = False
+
+ def setUp(self):
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def tearDown(self):
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def get_dummy_init_kwargs(self):
+ return {"weights_dtype": "float8"}
+
+ def get_dummy_model_init_kwargs(self):
+ return {
+ "pretrained_model_name_or_path": self.model_id,
+ "torch_dtype": self.torch_dtype,
+ "quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()),
+ }
+
+ def test_quanto_layers(self):
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ assert isinstance(module, QLinear)
+
+ def test_quanto_memory_usage(self):
+ inputs = self.get_dummy_inputs()
+ inputs = {
+ k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
+ }
+
+ unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
+ unquantized_model.to(torch_device)
+ unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)
+
+ quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ quantized_model.to(torch_device)
+ quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)
+
+ assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction
+
+ def test_keep_modules_in_fp32(self):
+ r"""
+ A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32.
+ Also ensures if inference works.
+ """
+ _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
+ self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module
+
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ model.to("cuda")
+
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name in model._keep_in_fp32_modules:
+ assert module.weight.dtype == torch.float32
+ self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
+
+ def test_modules_to_not_convert(self):
+ init_kwargs = self.get_dummy_model_init_kwargs()
+
+ quantization_config_kwargs = self.get_dummy_init_kwargs()
+ quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
+ quantization_config = QuantoConfig(**quantization_config_kwargs)
+
+ init_kwargs.update({"quantization_config": quantization_config})
+
+ model = self.model_cls.from_pretrained(**init_kwargs)
+ model.to("cuda")
+
+ for name, module in model.named_modules():
+ if name in self.modules_to_not_convert:
+ assert not isinstance(module, QLinear)
+
+ def test_dtype_assignment(self):
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+
+ with self.assertRaises(ValueError):
+ # Tries with a `dtype`
+ model.to(torch.float16)
+
+ with self.assertRaises(ValueError):
+ # Tries with a `device` and `dtype`
+ model.to(device="cuda:0", dtype=torch.float16)
+
+ with self.assertRaises(ValueError):
+ # Tries with a cast
+ model.float()
+
+ with self.assertRaises(ValueError):
+ # Tries with a cast
+ model.half()
+
+ # This should work
+ model.to("cuda")
+
+ def test_serialization(self):
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ inputs = self.get_dummy_inputs()
+
+ model.to(torch_device)
+ with torch.no_grad():
+ model_output = model(**inputs)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.save_pretrained(tmp_dir)
+ saved_model = self.model_cls.from_pretrained(
+ tmp_dir,
+ torch_dtype=torch.bfloat16,
+ )
+
+ saved_model.to(torch_device)
+ with torch.no_grad():
+ saved_model_output = saved_model(**inputs)
+
+ assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)
+
+ def test_torch_compile(self):
+ if not self._test_torch_compile:
+ return
+
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)
+
+ model.to(torch_device)
+ with torch.no_grad():
+ model_output = model(**self.get_dummy_inputs()).sample
+
+ compiled_model.to(torch_device)
+ with torch.no_grad():
+ compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample
+
+ model_output = model_output.detach().float().cpu().numpy()
+ compiled_model_output = compiled_model_output.detach().float().cpu().numpy()
+
+ max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
+ assert max_diff < 1e-3
+
+ def test_device_map_error(self):
+ with self.assertRaises(ValueError):
+ _ = self.model_cls.from_pretrained(
+ **self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"}
+ )
+
+
+class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
+ model_id = "hf-internal-testing/tiny-flux-transformer"
+ model_cls = FluxTransformer2DModel
+ pipeline_cls = FluxPipeline
+ torch_dtype = torch.bfloat16
+ keep_in_fp32_module = "proj_out"
+ modules_to_not_convert = ["proj_out"]
+ _test_torch_compile = False
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "pooled_projections": torch.randn(
+ (1, 768),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
+ }
+
+ def get_dummy_training_inputs(self, device=None, seed: int = 0):
+ batch_size = 1
+ num_latent_channels = 4
+ num_image_channels = 3
+ height = width = 4
+ sequence_length = 48
+ embedding_dim = 32
+
+ torch.manual_seed(seed)
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
+
+ torch.manual_seed(seed)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
+ device, dtype=torch.bfloat16
+ )
+
+ torch.manual_seed(seed)
+ pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
+
+ torch.manual_seed(seed)
+ text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+ torch.manual_seed(seed)
+ image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+ timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_prompt_embeds,
+ "txt_ids": text_ids,
+ "img_ids": image_ids,
+ "timestep": timestep,
+ }
+
+ def test_model_cpu_offload(self):
+ init_kwargs = self.get_dummy_init_kwargs()
+ transformer = self.model_cls.from_pretrained(
+ "hf-internal-testing/tiny-flux-pipe",
+ quantization_config=QuantoConfig(**init_kwargs),
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16,
+ )
+ pipe = self.pipeline_cls.from_pretrained(
+ "hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16
+ )
+ pipe.enable_model_cpu_offload(device=torch_device)
+ _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)
+
+ def test_training(self):
+ quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
+ quantized_model = self.model_cls.from_pretrained(
+ "hf-internal-testing/tiny-flux-pipe",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ ).to(torch_device)
+
+ for param in quantized_model.parameters():
+ # freeze the model as only adapter layers will be trained
+ param.requires_grad = False
+ if param.ndim == 1:
+ param.data = param.data.to(torch.float32)
+
+ for _, module in quantized_model.named_modules():
+ if isinstance(module, Attention):
+ module.to_q = LoRALayer(module.to_q, rank=4)
+ module.to_k = LoRALayer(module.to_k, rank=4)
+ module.to_v = LoRALayer(module.to_v, rank=4)
+
+ with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
+ inputs = self.get_dummy_training_inputs(torch_device)
+ output = quantized_model(**inputs)[0]
+ output.norm().backward()
+
+ for module in quantized_model.modules():
+ if isinstance(module, LoRALayer):
+ self.assertTrue(module.adapter[1].weight.grad is not None)
+
+
+class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
+ expected_memory_reduction = 0.6
+
+ def get_dummy_init_kwargs(self):
+ return {"weights_dtype": "float8"}
+
+
+class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
+ expected_memory_reduction = 0.6
+ _test_torch_compile = True
+
+ def get_dummy_init_kwargs(self):
+ return {"weights_dtype": "int8"}
+
+
+@require_torch_cuda_compatibility(8.0)
+class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
+ expected_memory_reduction = 0.55
+
+ def get_dummy_init_kwargs(self):
+ return {"weights_dtype": "int4"}
+
+
+@require_torch_cuda_compatibility(8.0)
+class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
+ expected_memory_reduction = 0.65
+
+ def get_dummy_init_kwargs(self):
+ return {"weights_dtype": "int2"}
diff --git a/tests/quantization/torchao/README.md b/tests/quantization/torchao/README.md
new file mode 100644
index 000000000000..fadc529e12fc
--- /dev/null
+++ b/tests/quantization/torchao/README.md
@@ -0,0 +1,53 @@
+The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/tests/quantization/torchao_integration/).
+
+The benchmarks were run on a single H100. Below is `nvidia-smi`:
+
+```bash
++---------------------------------------------------------------------------------------+
+| NVIDIA-SMI 535.104.12 Driver Version: 535.104.12 CUDA Version: 12.2 |
+|-----------------------------------------+----------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. |
+|=========================================+======================+======================|
+| 0 NVIDIA H100 80GB HBM3 On | 00000000:53:00.0 Off | 0 |
+| N/A 34C P0 69W / 700W | 2MiB / 81559MiB | 0% Default |
+| | | Disabled |
++-----------------------------------------+----------------------+----------------------+
+
++---------------------------------------------------------------------------------------+
+| Processes: |
+| GPU GI CI PID Type Process name GPU Memory |
+| ID ID Usage |
+|=======================================================================================|
+| No running processes found |
++---------------------------------------------------------------------------------------+
+```
+
+The benchmark results for Flux and CogVideoX can be found in [this](https://github.com/huggingface/diffusers/pull/10009) PR.
+
+The tests, and the expected slices, were obtained from the `aws-g6e-xlarge-plus` GPU test runners. To run the slow tests, use the following command or an equivalent:
+
+```bash
+HF_HUB_ENABLE_HF_TRANSFER=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
+```
+
+`diffusers-cli`:
+
+```bash
+- 🤗 Diffusers version: 0.32.0.dev0
+- Platform: Linux-5.15.0-1049-aws-x86_64-with-glibc2.31
+- Running on Google Colab?: No
+- Python version: 3.10.14
+- PyTorch version (GPU?): 2.6.0.dev20241112+cu121 (False)
+- Flax version (CPU?/GPU?/TPU?): not installed (NA)
+- Jax version: not installed
+- JaxLib version: not installed
+- Huggingface_hub version: 0.26.2
+- Transformers version: 4.46.3
+- Accelerate version: 1.1.1
+- PEFT version: not installed
+- Bitsandbytes version: not installed
+- Safetensors version: 0.4.5
+- xFormers version: not installed
+```
diff --git a/tests/quantization/torchao/__init__.py b/tests/quantization/torchao/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
new file mode 100644
index 000000000000..0e671307dd18
--- /dev/null
+++ b/tests/quantization/torchao/test_torchao.py
@@ -0,0 +1,839 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+from typing import List
+
+import numpy as np
+from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ FluxPipeline,
+ FluxTransformer2DModel,
+ TorchAoConfig,
+)
+from diffusers.models.attention_processor import Attention
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ is_torch_available,
+ is_torchao_available,
+ nightly,
+ numpy_cosine_similarity_distance,
+ require_torch,
+ require_torch_gpu,
+ require_torchao_version_greater_or_equal,
+ slow,
+ torch_device,
+)
+
+
+enable_full_determinism()
+
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+ from ..utils import LoRALayer, get_memory_consumption_stat
+
+
+if is_torchao_available():
+ from torchao.dtypes import AffineQuantizedTensor
+ from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+ from torchao.quantization.quant_primitives import MappingType
+ from torchao.utils import get_model_size_in_bytes
+
+
+@require_torch
+@require_torch_gpu
+@require_torchao_version_greater_or_equal("0.7.0")
+class TorchAoConfigTest(unittest.TestCase):
+ def test_to_dict(self):
+ """
+ Makes sure the config format is properly set
+ """
+ quantization_config = TorchAoConfig("int4_weight_only")
+ torchao_orig_config = quantization_config.to_dict()
+
+ for key in torchao_orig_config:
+ self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key])
+
+ def test_post_init_check(self):
+ """
+ Test kwargs validations in TorchAoConfig
+ """
+ _ = TorchAoConfig("int4_weight_only")
+ with self.assertRaisesRegex(ValueError, "is not supported yet"):
+ _ = TorchAoConfig("uint8")
+
+ with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
+ _ = TorchAoConfig("int4_weight_only", group_size1=32)
+
+ def test_repr(self):
+ """
+ Check that there is no error in the repr
+ """
+ quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
+ expected_repr = """TorchAoConfig {
+ "modules_to_not_convert": [
+ "conv"
+ ],
+ "quant_method": "torchao",
+ "quant_type": "int4_weight_only",
+ "quant_type_kwargs": {
+ "group_size": 8
+ }
+ }""".replace(" ", "").replace("\n", "")
+ quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
+ self.assertEqual(quantization_repr, expected_repr)
+
+ quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
+ expected_repr = """TorchAoConfig {
+ "modules_to_not_convert": null,
+ "quant_method": "torchao",
+ "quant_type": "int4dq",
+ "quant_type_kwargs": {
+ "act_mapping_type": "SYMMETRIC",
+ "group_size": 64
+ }
+ }""".replace(" ", "").replace("\n", "")
+ quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
+ self.assertEqual(quantization_repr, expected_repr)
+
+
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
+@require_torch
+@require_torch_gpu
+@require_torchao_version_greater_or_equal("0.7.0")
+class TorchAoTest(unittest.TestCase):
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_dummy_components(
+ self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
+ ):
+ transformer = FluxTransformer2DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ )
+ text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+ text_encoder_2 = T5EncoderModel.from_pretrained(
+ model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
+ )
+ tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+ tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
+ vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device: torch.device, seed: int = 0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator().manual_seed(seed)
+
+ inputs = {
+ "prompt": "an astronaut riding a horse in space",
+ "height": 32,
+ "width": 32,
+ "num_inference_steps": 2,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+ return inputs
+
+ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
+ batch_size = 1
+ num_latent_channels = 4
+ num_image_channels = 3
+ height = width = 4
+ sequence_length = 48
+ embedding_dim = 32
+
+ torch.manual_seed(seed)
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
+
+ torch.manual_seed(seed)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
+ device, dtype=torch.bfloat16
+ )
+
+ torch.manual_seed(seed)
+ pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
+
+ torch.manual_seed(seed)
+ text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+ torch.manual_seed(seed)
+ image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
+
+ timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_prompt_embeds,
+ "txt_ids": text_ids,
+ "img_ids": image_ids,
+ "timestep": timestep,
+ }
+
+ def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float], model_id: str):
+ components = self.get_dummy_components(quantization_config, model_id)
+ pipe = FluxPipeline(**components)
+ pipe.to(device=torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0]
+ output_slice = output[-1, -1, -3:, -3:].flatten()
+
+ self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+ def test_quantization(self):
+ for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
+ # fmt: off
+ QUANTIZATION_TYPES_TO_TEST = [
+ ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
+ ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
+ ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+ ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+ ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
+ ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+ ]
+
+ if TorchAoConfig._is_cuda_capability_atleast_8_9():
+ QUANTIZATION_TYPES_TO_TEST.extend([
+ ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
+ ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
+ # =====
+ # The following lead to an internal torch error:
+ # RuntimeError: mat2 shape (32x4 must be divisible by 16
+ # Skip these for now; TODO(aryan): investigate later
+ # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
+ # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
+ # =====
+ # Cutlass fails to initialize for below
+ # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
+ # =====
+ ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+ ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
+ ])
+ # fmt: on
+
+ for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
+ quant_kwargs = {}
+ if quantization_name in ["uint4wo", "uint7wo"]:
+ # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
+ quant_kwargs.update({"group_size": 16})
+ quantization_config = TorchAoConfig(
+ quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
+ )
+ self._test_quant_type(quantization_config, expected_slice, model_id)
+
+ def test_int4wo_quant_bfloat16_conversion(self):
+ """
+ Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
+ """
+ quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+ quantized_model = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-flux-pipe",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ )
+
+ weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+ self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+ self.assertEqual(weight.quant_min, 0)
+ self.assertEqual(weight.quant_max, 15)
+
+ def test_device_map(self):
+ """
+ Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
+ The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
+ correctly set (in the `hf_device_map` attribute of the model).
+ """
+ custom_device_map_dict = {
+ "time_text_embed": torch_device,
+ "context_embedder": torch_device,
+ "x_embedder": torch_device,
+ "transformer_blocks.0": "cpu",
+ "single_transformer_blocks.0": "disk",
+ "norm_out": torch_device,
+ "proj_out": "cpu",
+ }
+ device_maps = ["auto", custom_device_map_dict]
+
+ inputs = self.get_dummy_tensor_inputs(torch_device)
+ # requires with different expected slices since models are different due to offload (we don't quantize modules offloaded to cpu/disk)
+ expected_slice_auto = np.array(
+ [
+ 0.34179688,
+ -0.03613281,
+ 0.01428223,
+ -0.22949219,
+ -0.49609375,
+ 0.4375,
+ -0.1640625,
+ -0.66015625,
+ 0.43164062,
+ ]
+ )
+ expected_slice_offload = np.array(
+ [0.34375, -0.03515625, 0.0123291, -0.22753906, -0.49414062, 0.4375, -0.16308594, -0.66015625, 0.43554688]
+ )
+ for device_map in device_maps:
+ if device_map == "auto":
+ expected_slice = expected_slice_auto
+ else:
+ expected_slice = expected_slice_offload
+ with tempfile.TemporaryDirectory() as offload_folder:
+ quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+ quantized_model = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-flux-pipe",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ device_map=device_map,
+ torch_dtype=torch.bfloat16,
+ offload_folder=offload_folder,
+ )
+
+ weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+
+ # Note that when performing cpu/disk offload, the offloaded weights are not quantized, only the weights on the gpu.
+ # This is not the case when the model are already quantized
+ if "transformer_blocks.0" in device_map:
+ self.assertTrue(isinstance(weight, nn.Parameter))
+ else:
+ self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+
+ output = quantized_model(**inputs)[0]
+ output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+ self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+
+ with tempfile.TemporaryDirectory() as offload_folder:
+ quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+ quantized_model = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-flux-sharded",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ device_map=device_map,
+ torch_dtype=torch.bfloat16,
+ offload_folder=offload_folder,
+ )
+
+ weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+ if "transformer_blocks.0" in device_map:
+ self.assertTrue(isinstance(weight, nn.Parameter))
+ else:
+ self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+
+ output = quantized_model(**inputs)[0]
+ output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+ self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+
+ def test_modules_to_not_convert(self):
+ quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
+ quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-flux-pipe",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ )
+
+ unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
+ self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
+ self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
+ self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)
+
+ quantized_layer = quantized_model_with_not_convert.proj_out
+ self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
+
+ quantization_config = TorchAoConfig("int8_weight_only")
+ quantized_model = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-flux-pipe",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ )
+
+ size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
+ size_quantized = get_model_size_in_bytes(quantized_model)
+
+ self.assertTrue(size_quantized < size_quantized_with_not_convert)
+
+ def test_training(self):
+ quantization_config = TorchAoConfig("int8_weight_only")
+ quantized_model = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-flux-pipe",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ ).to(torch_device)
+
+ for param in quantized_model.parameters():
+ # freeze the model as only adapter layers will be trained
+ param.requires_grad = False
+ if param.ndim == 1:
+ param.data = param.data.to(torch.float32)
+
+ for _, module in quantized_model.named_modules():
+ if isinstance(module, Attention):
+ module.to_q = LoRALayer(module.to_q, rank=4)
+ module.to_k = LoRALayer(module.to_k, rank=4)
+ module.to_v = LoRALayer(module.to_v, rank=4)
+
+ with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
+ inputs = self.get_dummy_tensor_inputs(torch_device)
+ output = quantized_model(**inputs)[0]
+ output.norm().backward()
+
+ for module in quantized_model.modules():
+ if isinstance(module, LoRALayer):
+ self.assertTrue(module.adapter[1].weight.grad is not None)
+ self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
+
+ @nightly
+ def test_torch_compile(self):
+ r"""Test that verifies if torch.compile works with torchao quantization."""
+ for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
+ quantization_config = TorchAoConfig("int8_weight_only")
+ components = self.get_dummy_components(quantization_config, model_id=model_id)
+ pipe = FluxPipeline(**components)
+ pipe.to(device=torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ normal_output = pipe(**inputs)[0].flatten()[-32:]
+
+ pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
+ inputs = self.get_dummy_inputs(torch_device)
+ compile_output = pipe(**inputs)[0].flatten()[-32:]
+
+ # Note: Seems to require higher tolerance
+ self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))
+
+ def test_memory_footprint(self):
+ r"""
+ A simple test to check if the model conversion has been done correctly by checking on the
+ memory footprint of the converted model and the class type of the linear layers of the converted models
+ """
+ for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
+ transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
+ transformer_int4wo_gs32 = self.get_dummy_components(
+ TorchAoConfig("int4wo", group_size=32), model_id=model_id
+ )["transformer"]
+ transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
+ transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
+
+ # Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64
+ for block in transformer_int4wo.transformer_blocks:
+ self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor))
+ self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor))
+
+ # Will quantize all the linear layers except x_embedder
+ for name, module in transformer_int4wo_gs32.named_modules():
+ if isinstance(module, nn.Linear) and name not in ["x_embedder"]:
+ self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
+
+ # Will quantize all the linear layers
+ for module in transformer_int8wo.modules():
+ if isinstance(module, nn.Linear):
+ self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
+
+ total_int4wo = get_model_size_in_bytes(transformer_int4wo)
+ total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
+ total_int8wo = get_model_size_in_bytes(transformer_int8wo)
+ total_bf16 = get_model_size_in_bytes(transformer_bf16)
+
+ # TODO: refactor to align with other quantization tests
+ # Latter has smaller group size, so more groups -> more scales and zero points
+ self.assertTrue(total_int4wo < total_int4wo_gs32)
+ # int8 quantizes more layers compare to int4 with default group size
+ self.assertTrue(total_int8wo < total_int4wo)
+ # int4wo does not quantize too many layers because of default group size, but for the layers it does
+ # there is additional overhead of scales and zero points
+ self.assertTrue(total_bf16 < total_int4wo)
+
+ def test_model_memory_usage(self):
+ model_id = "hf-internal-testing/tiny-flux-pipe"
+ expected_memory_saving_ratio = 2.0
+
+ inputs = self.get_dummy_tensor_inputs(device=torch_device)
+
+ transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
+ transformer_bf16.to(torch_device)
+ unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
+ del transformer_bf16
+
+ transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
+ transformer_int8wo.to(torch_device)
+ quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
+ assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
+
+ def test_wrong_config(self):
+ with self.assertRaises(ValueError):
+ self.get_dummy_components(TorchAoConfig("int42"))
+
+ def test_sequential_cpu_offload(self):
+ r"""
+ A test that checks if inference runs as expected when sequential cpu offloading is enabled.
+ """
+ quantization_config = TorchAoConfig("int8wo")
+ components = self.get_dummy_components(quantization_config)
+ pipe = FluxPipeline(**components)
+ pipe.enable_sequential_cpu_offload()
+
+ inputs = self.get_dummy_inputs(torch_device)
+ _ = pipe(**inputs)
+
+
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
+@require_torch
+@require_torch_gpu
+@require_torchao_version_greater_or_equal("0.7.0")
+class TorchAoSerializationTest(unittest.TestCase):
+ model_name = "hf-internal-testing/tiny-flux-pipe"
+
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
+ quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
+ quantized_model = FluxTransformer2DModel.from_pretrained(
+ self.model_name,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ )
+ return quantized_model.to(device)
+
+ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
+ batch_size = 1
+ num_latent_channels = 4
+ num_image_channels = 3
+ height = width = 4
+ sequence_length = 48
+ embedding_dim = 32
+
+ torch.manual_seed(seed)
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
+ device, dtype=torch.bfloat16
+ )
+ pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
+ text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
+ image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
+ timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_prompt_embeds,
+ "txt_ids": text_ids,
+ "img_ids": image_ids,
+ "timestep": timestep,
+ }
+
+ def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
+ quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
+ inputs = self.get_dummy_tensor_inputs(torch_device)
+ output = quantized_model(**inputs)[0]
+ output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+ weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+ self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
+ self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+
+ def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
+ quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
+ loaded_quantized_model = FluxTransformer2DModel.from_pretrained(
+ tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
+ ).to(device=torch_device)
+
+ inputs = self.get_dummy_tensor_inputs(torch_device)
+ output = loaded_quantized_model(**inputs)[0]
+
+ output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+ self.assertTrue(
+ isinstance(
+ loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)
+ )
+ )
+ self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+
+ def test_int_a8w8_cuda(self):
+ quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+ expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+ device = "cuda"
+ self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+ self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+ def test_int_a16w8_cuda(self):
+ quant_method, quant_method_kwargs = "int8_weight_only", {}
+ expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+ device = "cuda"
+ self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+ self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+ def test_int_a8w8_cpu(self):
+ quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
+ expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
+ device = "cpu"
+ self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+ self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+ def test_int_a16w8_cpu(self):
+ quant_method, quant_method_kwargs = "int8_weight_only", {}
+ expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+ device = "cpu"
+ self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+ self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+
+# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
+@require_torch
+@require_torch_gpu
+@require_torchao_version_greater_or_equal("0.7.0")
+@slow
+@nightly
+class SlowTorchAoTests(unittest.TestCase):
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_dummy_components(self, quantization_config: TorchAoConfig):
+ # This is just for convenience, so that we can modify it at one place for custom environments and locally testing
+ cache_dir = None
+ model_id = "black-forest-labs/FLUX.1-dev"
+ transformer = FluxTransformer2DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ cache_dir=cache_dir,
+ )
+ text_encoder = CLIPTextModel.from_pretrained(
+ model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
+ )
+ text_encoder_2 = T5EncoderModel.from_pretrained(
+ model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
+ )
+ tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
+ tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
+ vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device: torch.device, seed: int = 0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator().manual_seed(seed)
+
+ inputs = {
+ "prompt": "an astronaut riding a horse in space",
+ "height": 512,
+ "width": 512,
+ "num_inference_steps": 20,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+ return inputs
+
+ def _test_quant_type(self, quantization_config, expected_slice):
+ components = self.get_dummy_components(quantization_config)
+ pipe = FluxPipeline(**components)
+ pipe.enable_model_cpu_offload()
+
+ weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight
+ self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0].flatten()
+ output_slice = np.concatenate((output[:16], output[-16:]))
+ self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+ def test_quantization(self):
+ # fmt: off
+ QUANTIZATION_TYPES_TO_TEST = [
+ ("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
+ ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
+ ]
+
+ if TorchAoConfig._is_cuda_capability_atleast_8_9():
+ QUANTIZATION_TYPES_TO_TEST.extend([
+ ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
+ ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
+ ])
+ # fmt: on
+
+ for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
+ quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
+ self._test_quant_type(quantization_config, expected_slice)
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ def test_serialization_int8wo(self):
+ quantization_config = TorchAoConfig("int8wo")
+ components = self.get_dummy_components(quantization_config)
+ pipe = FluxPipeline(**components)
+ pipe.enable_model_cpu_offload()
+
+ weight = pipe.transformer.x_embedder.weight
+ self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0].flatten()[:128]
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False)
+ pipe.remove_all_hooks()
+ del pipe.transformer
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ transformer = FluxTransformer2DModel.from_pretrained(
+ tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
+ )
+ pipe.transformer = transformer
+ pipe.enable_model_cpu_offload()
+
+ weight = transformer.x_embedder.weight
+ self.assertTrue(isinstance(weight, AffineQuantizedTensor))
+
+ loaded_output = pipe(**inputs)[0].flatten()[:128]
+ # Seems to require higher tolerance depending on which machine it is being run.
+ # A difference of 0.06 in normalized pixel space (-1 to 1), corresponds to a difference of
+ # 0.06 / 2 * 255 = 7.65 in pixel space (0 to 255). On our CI runners, the difference is about 0.04,
+ # on DGX it is 0.06, and on audace it is 0.037. So, we are using a tolerance of 0.06 here.
+ self.assertTrue(np.allclose(output, loaded_output, atol=0.06))
+
+ def test_memory_footprint_int4wo(self):
+ # The original checkpoints are in bf16 and about 24 GB
+ expected_memory_in_gb = 6.0
+ quantization_config = TorchAoConfig("int4wo")
+ cache_dir = None
+ transformer = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ cache_dir=cache_dir,
+ )
+ int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
+ self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb)
+
+ def test_memory_footprint_int8wo(self):
+ # The original checkpoints are in bf16 and about 24 GB
+ expected_memory_in_gb = 12.0
+ quantization_config = TorchAoConfig("int8wo")
+ cache_dir = None
+ transformer = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ cache_dir=cache_dir,
+ )
+ int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
+ self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)
+
+
+@require_torch
+@require_torch_gpu
+@require_torchao_version_greater_or_equal("0.7.0")
+@slow
+@nightly
+class SlowTorchAoPreserializedModelTests(unittest.TestCase):
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_dummy_inputs(self, device: torch.device, seed: int = 0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator().manual_seed(seed)
+
+ inputs = {
+ "prompt": "an astronaut riding a horse in space",
+ "height": 512,
+ "width": 512,
+ "num_inference_steps": 20,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+ return inputs
+
+ def test_transformer_int8wo(self):
+ # fmt: off
+ expected_slice = np.array([0.0566, 0.0781, 0.1426, 0.0488, 0.0684, 0.1504, 0.0625, 0.0781, 0.1445, 0.0625, 0.0781, 0.1562, 0.0547, 0.0723, 0.1484, 0.0566, 0.5703, 0.8867, 0.7266, 0.5742, 0.875, 0.7148, 0.5586, 0.875, 0.7148, 0.5547, 0.8633, 0.7109, 0.5469, 0.8398, 0.6992, 0.5703])
+ # fmt: on
+
+ # This is just for convenience, so that we can modify it at one place for custom environments and locally testing
+ cache_dir = None
+ transformer = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer",
+ torch_dtype=torch.bfloat16,
+ use_safetensors=False,
+ cache_dir=cache_dir,
+ )
+ pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir
+ )
+ pipe.enable_model_cpu_offload()
+
+ # Verify that all linear layer weights are quantized
+ for name, module in pipe.transformer.named_modules():
+ if isinstance(module, nn.Linear):
+ self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
+
+ # Verify outputs match expected slice
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0].flatten()
+ output_slice = np.concatenate((output[:16], output[-16:]))
+ self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py
new file mode 100644
index 000000000000..04ebf9e159f4
--- /dev/null
+++ b/tests/quantization/utils.py
@@ -0,0 +1,38 @@
+from diffusers.utils import is_torch_available
+
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+ class LoRALayer(nn.Module):
+ """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
+
+ Taken from
+ https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
+ """
+
+ def __init__(self, module: nn.Module, rank: int):
+ super().__init__()
+ self.module = module
+ self.adapter = nn.Sequential(
+ nn.Linear(module.in_features, rank, bias=False),
+ nn.Linear(rank, module.out_features, bias=False),
+ )
+ small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+ nn.init.normal_(self.adapter[0].weight, std=small_std)
+ nn.init.zeros_(self.adapter[1].weight)
+ self.adapter.to(module.weight.device)
+
+ def forward(self, input, *args, **kwargs):
+ return self.module(input, *args, **kwargs) + self.adapter(input)
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def get_memory_consumption_stat(model, inputs):
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.empty_cache()
+
+ model(**inputs)
+ max_memory_mem_allocated = torch.cuda.max_memory_allocated()
+ return max_memory_mem_allocated
diff --git a/tests/remote/__init__.py b/tests/remote/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py
new file mode 100644
index 000000000000..cec96e729a48
--- /dev/null
+++ b/tests/remote/test_remote_decode.py
@@ -0,0 +1,536 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from typing import Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.utils.constants import (
+ DECODE_ENDPOINT_FLUX,
+ DECODE_ENDPOINT_HUNYUAN_VIDEO,
+ DECODE_ENDPOINT_SD_V1,
+ DECODE_ENDPOINT_SD_XL,
+)
+from diffusers.utils.remote_utils import (
+ remote_decode,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ slow,
+ torch_all_close,
+ torch_device,
+)
+from diffusers.video_processor import VideoProcessor
+
+
+enable_full_determinism()
+
+
+class RemoteAutoencoderKLMixin:
+ shape: Tuple[int, ...] = None
+ out_hw: Tuple[int, int] = None
+ endpoint: str = None
+ dtype: torch.dtype = None
+ scaling_factor: float = None
+ shift_factor: float = None
+ processor_cls: Union[VaeImageProcessor, VideoProcessor] = None
+ output_pil_slice: torch.Tensor = None
+ output_pt_slice: torch.Tensor = None
+ partial_postprocess_return_pt_slice: torch.Tensor = None
+ return_pt_slice: torch.Tensor = None
+ width: int = None
+ height: int = None
+
+ def get_dummy_inputs(self):
+ inputs = {
+ "endpoint": self.endpoint,
+ "tensor": torch.randn(
+ self.shape,
+ device=torch_device,
+ dtype=self.dtype,
+ generator=torch.Generator(torch_device).manual_seed(13),
+ ),
+ "scaling_factor": self.scaling_factor,
+ "shift_factor": self.shift_factor,
+ "height": self.height,
+ "width": self.width,
+ }
+ return inputs
+
+ def test_no_scaling(self):
+ inputs = self.get_dummy_inputs()
+ if inputs["scaling_factor"] is not None:
+ inputs["tensor"] = inputs["tensor"] / inputs["scaling_factor"]
+ inputs["scaling_factor"] = None
+ if inputs["shift_factor"] is not None:
+ inputs["tensor"] = inputs["tensor"] + inputs["shift_factor"]
+ inputs["shift_factor"] = None
+ processor = self.processor_cls()
+ output = remote_decode(
+ output_type="pt",
+ # required for now, will be removed in next update
+ do_scaling=False,
+ processor=processor,
+ **inputs,
+ )
+ assert isinstance(output, PIL.Image.Image)
+ self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
+ self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
+ self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
+ output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
+ # Increased tolerance for Flux Packed diff [1, 0, 1, 0, 0, 0, 0, 0, 0]
+ self.assertTrue(
+ torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
+ f"{output_slice}",
+ )
+
+ def test_output_type_pt(self):
+ inputs = self.get_dummy_inputs()
+ processor = self.processor_cls()
+ output = remote_decode(output_type="pt", processor=processor, **inputs)
+ assert isinstance(output, PIL.Image.Image)
+ self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
+ self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
+ self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
+ output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
+ self.assertTrue(
+ torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
+ )
+
+ # output is visually the same, slice is flaky?
+ def test_output_type_pil(self):
+ inputs = self.get_dummy_inputs()
+ output = remote_decode(output_type="pil", **inputs)
+ self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
+ self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
+ self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
+
+ def test_output_type_pil_image_format(self):
+ inputs = self.get_dummy_inputs()
+ output = remote_decode(output_type="pil", image_format="png", **inputs)
+ self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
+ self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
+ self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
+ self.assertEqual(output.format, "png", f"Expected image format `png`, got {output.format}")
+ output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
+ self.assertTrue(
+ torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
+ )
+
+ def test_output_type_pt_partial_postprocess(self):
+ inputs = self.get_dummy_inputs()
+ output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
+ self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
+ self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
+ self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
+ output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
+ self.assertTrue(
+ torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
+ )
+
+ def test_output_type_pt_return_type_pt(self):
+ inputs = self.get_dummy_inputs()
+ output = remote_decode(output_type="pt", return_type="pt", **inputs)
+ self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}")
+ self.assertEqual(
+ output.shape[2], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[2]}"
+ )
+ self.assertEqual(
+ output.shape[3], self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.shape[3]}"
+ )
+ output_slice = output[0, 0, -3:, -3:].flatten()
+ self.assertTrue(
+ torch_all_close(output_slice, self.return_pt_slice.to(output_slice.dtype), rtol=1e-3, atol=1e-3),
+ f"{output_slice}",
+ )
+
+ def test_output_type_pt_partial_postprocess_return_type_pt(self):
+ inputs = self.get_dummy_inputs()
+ output = remote_decode(output_type="pt", partial_postprocess=True, return_type="pt", **inputs)
+ self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}")
+ self.assertEqual(
+ output.shape[1], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[1]}"
+ )
+ self.assertEqual(
+ output.shape[2], self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.shape[2]}"
+ )
+ output_slice = output[0, -3:, -3:, 0].flatten().cpu()
+ self.assertTrue(
+ torch_all_close(output_slice, self.partial_postprocess_return_pt_slice.to(output_slice.dtype), rtol=1e-2),
+ f"{output_slice}",
+ )
+
+ def test_do_scaling_deprecation(self):
+ inputs = self.get_dummy_inputs()
+ inputs.pop("scaling_factor", None)
+ inputs.pop("shift_factor", None)
+ with self.assertWarns(FutureWarning) as warning:
+ _ = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
+ self.assertEqual(
+ str(warning.warnings[0].message),
+ "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
+ str(warning.warnings[0].message),
+ )
+
+ def test_input_tensor_type_base64_deprecation(self):
+ inputs = self.get_dummy_inputs()
+ with self.assertWarns(FutureWarning) as warning:
+ _ = remote_decode(output_type="pt", input_tensor_type="base64", partial_postprocess=True, **inputs)
+ self.assertEqual(
+ str(warning.warnings[0].message),
+ "input_tensor_type='base64' is deprecated. Using `binary`.",
+ str(warning.warnings[0].message),
+ )
+
+ def test_output_tensor_type_base64_deprecation(self):
+ inputs = self.get_dummy_inputs()
+ with self.assertWarns(FutureWarning) as warning:
+ _ = remote_decode(output_type="pt", output_tensor_type="base64", partial_postprocess=True, **inputs)
+ self.assertEqual(
+ str(warning.warnings[0].message),
+ "output_tensor_type='base64' is deprecated. Using `binary`.",
+ str(warning.warnings[0].message),
+ )
+
+
+class RemoteAutoencoderKLHunyuanVideoMixin(RemoteAutoencoderKLMixin):
+ def test_no_scaling(self):
+ inputs = self.get_dummy_inputs()
+ if inputs["scaling_factor"] is not None:
+ inputs["tensor"] = inputs["tensor"] / inputs["scaling_factor"]
+ inputs["scaling_factor"] = None
+ if inputs["shift_factor"] is not None:
+ inputs["tensor"] = inputs["tensor"] + inputs["shift_factor"]
+ inputs["shift_factor"] = None
+ processor = self.processor_cls()
+ output = remote_decode(
+ output_type="pt",
+ # required for now, will be removed in next update
+ do_scaling=False,
+ processor=processor,
+ **inputs,
+ )
+ self.assertTrue(
+ isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
+ f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
+ )
+ self.assertEqual(
+ output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
+ )
+ self.assertEqual(
+ output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output[0].width}"
+ )
+ output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
+ self.assertTrue(
+ torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
+ f"{output_slice}",
+ )
+
+ def test_output_type_pt(self):
+ inputs = self.get_dummy_inputs()
+ processor = self.processor_cls()
+ output = remote_decode(output_type="pt", processor=processor, **inputs)
+ self.assertTrue(
+ isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
+ f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
+ )
+ self.assertEqual(
+ output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
+ )
+ self.assertEqual(
+ output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output[0].width}"
+ )
+ output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
+ self.assertTrue(
+ torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
+ f"{output_slice}",
+ )
+
+ # output is visually the same, slice is flaky?
+ def test_output_type_pil(self):
+ inputs = self.get_dummy_inputs()
+ processor = self.processor_cls()
+ output = remote_decode(output_type="pil", processor=processor, **inputs)
+ self.assertTrue(
+ isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
+ f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
+ )
+ self.assertEqual(
+ output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
+ )
+ self.assertEqual(
+ output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output[0].width}"
+ )
+
+ def test_output_type_pil_image_format(self):
+ inputs = self.get_dummy_inputs()
+ processor = self.processor_cls()
+ output = remote_decode(output_type="pil", processor=processor, image_format="png", **inputs)
+ self.assertTrue(
+ isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
+ f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
+ )
+ self.assertEqual(
+ output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
+ )
+ self.assertEqual(
+ output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output[0].width}"
+ )
+ output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
+ self.assertTrue(
+ torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
+ f"{output_slice}",
+ )
+
+ def test_output_type_pt_partial_postprocess(self):
+ inputs = self.get_dummy_inputs()
+ output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
+ self.assertTrue(
+ isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
+ f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
+ )
+ self.assertEqual(
+ output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
+ )
+ self.assertEqual(
+ output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output[0].width}"
+ )
+ output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
+ self.assertTrue(
+ torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
+ f"{output_slice}",
+ )
+
+ def test_output_type_pt_return_type_pt(self):
+ inputs = self.get_dummy_inputs()
+ output = remote_decode(output_type="pt", return_type="pt", **inputs)
+ self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}")
+ self.assertEqual(
+ output.shape[3], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[2]}"
+ )
+ self.assertEqual(
+ output.shape[4], self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.shape[3]}"
+ )
+ output_slice = output[0, 0, 0, -3:, -3:].flatten()
+ self.assertTrue(
+ torch_all_close(output_slice, self.return_pt_slice.to(output_slice.dtype), rtol=1e-3, atol=1e-3),
+ f"{output_slice}",
+ )
+
+ def test_output_type_mp4(self):
+ inputs = self.get_dummy_inputs()
+ output = remote_decode(output_type="mp4", return_type="mp4", **inputs)
+ self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}")
+
+
+class RemoteAutoencoderKLSDv1Tests(
+ RemoteAutoencoderKLMixin,
+ unittest.TestCase,
+):
+ shape = (
+ 1,
+ 4,
+ 64,
+ 64,
+ )
+ out_hw = (
+ 512,
+ 512,
+ )
+ endpoint = DECODE_ENDPOINT_SD_V1
+ dtype = torch.float16
+ scaling_factor = 0.18215
+ shift_factor = None
+ processor_cls = VaeImageProcessor
+ output_pt_slice = torch.tensor([31, 15, 11, 55, 30, 21, 66, 42, 30], dtype=torch.uint8)
+ partial_postprocess_return_pt_slice = torch.tensor([100, 130, 99, 133, 106, 112, 97, 100, 121], dtype=torch.uint8)
+ return_pt_slice = torch.tensor([-0.2177, 0.0217, -0.2258, 0.0412, -0.1687, -0.1232, -0.2416, -0.2130, -0.0543])
+
+
+class RemoteAutoencoderKLSDXLTests(
+ RemoteAutoencoderKLMixin,
+ unittest.TestCase,
+):
+ shape = (
+ 1,
+ 4,
+ 128,
+ 128,
+ )
+ out_hw = (
+ 1024,
+ 1024,
+ )
+ endpoint = DECODE_ENDPOINT_SD_XL
+ dtype = torch.float16
+ scaling_factor = 0.13025
+ shift_factor = None
+ processor_cls = VaeImageProcessor
+ output_pt_slice = torch.tensor([104, 52, 23, 114, 61, 35, 108, 87, 38], dtype=torch.uint8)
+ partial_postprocess_return_pt_slice = torch.tensor([77, 86, 89, 49, 60, 75, 52, 65, 78], dtype=torch.uint8)
+ return_pt_slice = torch.tensor([-0.3945, -0.3289, -0.2993, -0.6177, -0.5259, -0.4119, -0.5898, -0.4863, -0.3845])
+
+
+class RemoteAutoencoderKLFluxTests(
+ RemoteAutoencoderKLMixin,
+ unittest.TestCase,
+):
+ shape = (
+ 1,
+ 16,
+ 128,
+ 128,
+ )
+ out_hw = (
+ 1024,
+ 1024,
+ )
+ endpoint = DECODE_ENDPOINT_FLUX
+ dtype = torch.bfloat16
+ scaling_factor = 0.3611
+ shift_factor = 0.1159
+ processor_cls = VaeImageProcessor
+ output_pt_slice = torch.tensor([110, 72, 91, 62, 35, 52, 69, 55, 69], dtype=torch.uint8)
+ partial_postprocess_return_pt_slice = torch.tensor(
+ [202, 203, 203, 197, 195, 193, 189, 188, 178], dtype=torch.uint8
+ )
+ return_pt_slice = torch.tensor([0.5820, 0.5962, 0.5898, 0.5439, 0.5327, 0.5112, 0.4797, 0.4773, 0.3984])
+
+
+class RemoteAutoencoderKLFluxPackedTests(
+ RemoteAutoencoderKLMixin,
+ unittest.TestCase,
+):
+ shape = (
+ 1,
+ 4096,
+ 64,
+ )
+ out_hw = (
+ 1024,
+ 1024,
+ )
+ height = 1024
+ width = 1024
+ endpoint = DECODE_ENDPOINT_FLUX
+ dtype = torch.bfloat16
+ scaling_factor = 0.3611
+ shift_factor = 0.1159
+ processor_cls = VaeImageProcessor
+ # slices are different due to randn on different shape. we can pack the latent instead if we want the same
+ output_pt_slice = torch.tensor([96, 116, 157, 45, 67, 104, 34, 56, 89], dtype=torch.uint8)
+ partial_postprocess_return_pt_slice = torch.tensor(
+ [168, 212, 202, 155, 191, 185, 150, 180, 168], dtype=torch.uint8
+ )
+ return_pt_slice = torch.tensor([0.3198, 0.6631, 0.5864, 0.2131, 0.4944, 0.4482, 0.1776, 0.4153, 0.3176])
+
+
+class RemoteAutoencoderKLHunyuanVideoTests(
+ RemoteAutoencoderKLHunyuanVideoMixin,
+ unittest.TestCase,
+):
+ shape = (
+ 1,
+ 16,
+ 3,
+ 40,
+ 64,
+ )
+ out_hw = (
+ 320,
+ 512,
+ )
+ endpoint = DECODE_ENDPOINT_HUNYUAN_VIDEO
+ dtype = torch.float16
+ scaling_factor = 0.476986
+ processor_cls = VideoProcessor
+ output_pt_slice = torch.tensor([112, 92, 85, 112, 93, 85, 112, 94, 85], dtype=torch.uint8)
+ partial_postprocess_return_pt_slice = torch.tensor(
+ [149, 161, 168, 136, 150, 156, 129, 143, 149], dtype=torch.uint8
+ )
+ return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])
+
+
+class RemoteAutoencoderKLSlowTestMixin:
+ channels: int = 4
+ endpoint: str = None
+ dtype: torch.dtype = None
+ scaling_factor: float = None
+ shift_factor: float = None
+ width: int = None
+ height: int = None
+
+ def get_dummy_inputs(self):
+ inputs = {
+ "endpoint": self.endpoint,
+ "scaling_factor": self.scaling_factor,
+ "shift_factor": self.shift_factor,
+ "height": self.height,
+ "width": self.width,
+ }
+ return inputs
+
+ def test_multi_res(self):
+ inputs = self.get_dummy_inputs()
+ for height in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
+ for width in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
+ inputs["tensor"] = torch.randn(
+ (1, self.channels, height // 8, width // 8),
+ device=torch_device,
+ dtype=self.dtype,
+ generator=torch.Generator(torch_device).manual_seed(13),
+ )
+ inputs["height"] = height
+ inputs["width"] = width
+ output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
+ output.save(f"test_multi_res_{height}_{width}.png")
+
+
+@slow
+class RemoteAutoencoderKLSDv1SlowTests(
+ RemoteAutoencoderKLSlowTestMixin,
+ unittest.TestCase,
+):
+ endpoint = DECODE_ENDPOINT_SD_V1
+ dtype = torch.float16
+ scaling_factor = 0.18215
+ shift_factor = None
+
+
+@slow
+class RemoteAutoencoderKLSDXLSlowTests(
+ RemoteAutoencoderKLSlowTestMixin,
+ unittest.TestCase,
+):
+ endpoint = DECODE_ENDPOINT_SD_XL
+ dtype = torch.float16
+ scaling_factor = 0.13025
+ shift_factor = None
+
+
+@slow
+class RemoteAutoencoderKLFluxSlowTests(
+ RemoteAutoencoderKLSlowTestMixin,
+ unittest.TestCase,
+):
+ channels = 16
+ endpoint = DECODE_ENDPOINT_FLUX
+ dtype = torch.bfloat16
+ scaling_factor = 0.3611
+ shift_factor = 0.1159
diff --git a/tests/remote/test_remote_encode.py b/tests/remote/test_remote_encode.py
new file mode 100644
index 000000000000..62ed97ee8f49
--- /dev/null
+++ b/tests/remote/test_remote_encode.py
@@ -0,0 +1,224 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import PIL.Image
+import torch
+
+from diffusers.utils import load_image
+from diffusers.utils.constants import (
+ DECODE_ENDPOINT_FLUX,
+ DECODE_ENDPOINT_SD_V1,
+ DECODE_ENDPOINT_SD_XL,
+ ENCODE_ENDPOINT_FLUX,
+ ENCODE_ENDPOINT_SD_V1,
+ ENCODE_ENDPOINT_SD_XL,
+)
+from diffusers.utils.remote_utils import (
+ remote_decode,
+ remote_encode,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ slow,
+)
+
+
+enable_full_determinism()
+
+IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"
+
+
+class RemoteAutoencoderKLEncodeMixin:
+ channels: int = None
+ endpoint: str = None
+ decode_endpoint: str = None
+ dtype: torch.dtype = None
+ scaling_factor: float = None
+ shift_factor: float = None
+ image: PIL.Image.Image = None
+
+ def get_dummy_inputs(self):
+ if self.image is None:
+ self.image = load_image(IMAGE)
+ inputs = {
+ "endpoint": self.endpoint,
+ "image": self.image,
+ "scaling_factor": self.scaling_factor,
+ "shift_factor": self.shift_factor,
+ }
+ return inputs
+
+ def test_image_input(self):
+ inputs = self.get_dummy_inputs()
+ height, width = inputs["image"].height, inputs["image"].width
+ output = remote_encode(**inputs)
+ self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
+ decoded = remote_decode(
+ tensor=output,
+ endpoint=self.decode_endpoint,
+ scaling_factor=self.scaling_factor,
+ shift_factor=self.shift_factor,
+ image_format="png",
+ )
+ self.assertEqual(decoded.height, height)
+ self.assertEqual(decoded.width, width)
+ # image_slice = torch.from_numpy(np.array(inputs["image"])[0, -3:, -3:].flatten())
+ # decoded_slice = torch.from_numpy(np.array(decoded)[0, -3:, -3:].flatten())
+ # TODO: how to test this? encode->decode is lossy. expected slice of encoded latent?
+
+
+class RemoteAutoencoderKLSDv1Tests(
+ RemoteAutoencoderKLEncodeMixin,
+ unittest.TestCase,
+):
+ channels = 4
+ endpoint = ENCODE_ENDPOINT_SD_V1
+ decode_endpoint = DECODE_ENDPOINT_SD_V1
+ dtype = torch.float16
+ scaling_factor = 0.18215
+ shift_factor = None
+
+
+class RemoteAutoencoderKLSDXLTests(
+ RemoteAutoencoderKLEncodeMixin,
+ unittest.TestCase,
+):
+ channels = 4
+ endpoint = ENCODE_ENDPOINT_SD_XL
+ decode_endpoint = DECODE_ENDPOINT_SD_XL
+ dtype = torch.float16
+ scaling_factor = 0.13025
+ shift_factor = None
+
+
+class RemoteAutoencoderKLFluxTests(
+ RemoteAutoencoderKLEncodeMixin,
+ unittest.TestCase,
+):
+ channels = 16
+ endpoint = ENCODE_ENDPOINT_FLUX
+ decode_endpoint = DECODE_ENDPOINT_FLUX
+ dtype = torch.bfloat16
+ scaling_factor = 0.3611
+ shift_factor = 0.1159
+
+
+class RemoteAutoencoderKLEncodeSlowTestMixin:
+ channels: int = 4
+ endpoint: str = None
+ decode_endpoint: str = None
+ dtype: torch.dtype = None
+ scaling_factor: float = None
+ shift_factor: float = None
+ image: PIL.Image.Image = None
+
+ def get_dummy_inputs(self):
+ if self.image is None:
+ self.image = load_image(IMAGE)
+ inputs = {
+ "endpoint": self.endpoint,
+ "image": self.image,
+ "scaling_factor": self.scaling_factor,
+ "shift_factor": self.shift_factor,
+ }
+ return inputs
+
+ def test_multi_res(self):
+ inputs = self.get_dummy_inputs()
+ for height in {
+ 320,
+ 512,
+ 640,
+ 704,
+ 896,
+ 1024,
+ 1208,
+ 1384,
+ 1536,
+ 1608,
+ 1864,
+ 2048,
+ }:
+ for width in {
+ 320,
+ 512,
+ 640,
+ 704,
+ 896,
+ 1024,
+ 1208,
+ 1384,
+ 1536,
+ 1608,
+ 1864,
+ 2048,
+ }:
+ inputs["image"] = inputs["image"].resize(
+ (
+ width,
+ height,
+ )
+ )
+ output = remote_encode(**inputs)
+ self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
+ decoded = remote_decode(
+ tensor=output,
+ endpoint=self.decode_endpoint,
+ scaling_factor=self.scaling_factor,
+ shift_factor=self.shift_factor,
+ image_format="png",
+ )
+ self.assertEqual(decoded.height, height)
+ self.assertEqual(decoded.width, width)
+ decoded.save(f"test_multi_res_{height}_{width}.png")
+
+
+@slow
+class RemoteAutoencoderKLSDv1SlowTests(
+ RemoteAutoencoderKLEncodeSlowTestMixin,
+ unittest.TestCase,
+):
+ endpoint = ENCODE_ENDPOINT_SD_V1
+ decode_endpoint = DECODE_ENDPOINT_SD_V1
+ dtype = torch.float16
+ scaling_factor = 0.18215
+ shift_factor = None
+
+
+@slow
+class RemoteAutoencoderKLSDXLSlowTests(
+ RemoteAutoencoderKLEncodeSlowTestMixin,
+ unittest.TestCase,
+):
+ endpoint = ENCODE_ENDPOINT_SD_XL
+ decode_endpoint = DECODE_ENDPOINT_SD_XL
+ dtype = torch.float16
+ scaling_factor = 0.13025
+ shift_factor = None
+
+
+@slow
+class RemoteAutoencoderKLFluxSlowTests(
+ RemoteAutoencoderKLEncodeSlowTestMixin,
+ unittest.TestCase,
+):
+ channels = 16
+ endpoint = ENCODE_ENDPOINT_FLUX
+ decode_endpoint = DECODE_ENDPOINT_FLUX
+ dtype = torch.bfloat16
+ scaling_factor = 0.3611
+ shift_factor = 0.1159
diff --git a/tests/schedulers/test_scheduler_ddim_inverse.py b/tests/schedulers/test_scheduler_ddim_inverse.py
index 696f57644a83..81d53f1b4778 100644
--- a/tests/schedulers/test_scheduler_ddim_inverse.py
+++ b/tests/schedulers/test_scheduler_ddim_inverse.py
@@ -1,3 +1,5 @@
+import unittest
+
import torch
from diffusers import DDIMInverseScheduler
@@ -95,6 +97,7 @@ def test_inference_steps(self):
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
+ @unittest.skip("Test not supported.")
def test_add_noise_device(self):
pass
diff --git a/tests/schedulers/test_scheduler_deis.py b/tests/schedulers/test_scheduler_deis.py
index b2823a0cb47e..048bde51c366 100644
--- a/tests/schedulers/test_scheduler_deis.py
+++ b/tests/schedulers/test_scheduler_deis.py
@@ -1,4 +1,5 @@
import tempfile
+import unittest
import torch
@@ -57,6 +58,7 @@ def check_over_configs(self, time_step=0, **config):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+ @unittest.skip("Test not supported.")
def test_from_save_pretrained(self):
pass
@@ -263,3 +265,9 @@ def test_full_loop_with_noise(self):
assert abs(result_sum.item() - 315.3016) < 1e-2, f" expected result sum 315.3016, but get {result_sum}"
assert abs(result_mean.item() - 0.41054) < 1e-3, f" expected result mean 0.41054, but get {result_mean}"
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index ef407eaa3dc9..55b3202ad0be 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -1,4 +1,5 @@
import tempfile
+import unittest
import torch
@@ -67,6 +68,7 @@ def check_over_configs(self, time_step=0, **config):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+ @unittest.skip("Test not supported.")
def test_from_save_pretrained(self):
pass
@@ -358,3 +360,9 @@ def test_custom_timesteps(self):
assert (
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
), f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_dpm_multi_inverse.py b/tests/schedulers/test_scheduler_dpm_multi_inverse.py
index 014c901680e3..0eced957190c 100644
--- a/tests/schedulers/test_scheduler_dpm_multi_inverse.py
+++ b/tests/schedulers/test_scheduler_dpm_multi_inverse.py
@@ -265,3 +265,9 @@ def test_unique_timesteps(self, **config):
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_dpm_sde.py b/tests/schedulers/test_scheduler_dpm_sde.py
index 253a0a478b41..69b611173423 100644
--- a/tests/schedulers/test_scheduler_dpm_sde.py
+++ b/tests/schedulers/test_scheduler_dpm_sde.py
@@ -64,7 +64,7 @@ def test_full_loop_no_noise(self):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 167.47821044921875) < 1e-2
assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
- elif torch_device in ["cuda"]:
+ elif torch_device in ["cuda", "xpu"]:
assert abs(result_sum.item() - 171.59352111816406) < 1e-2
assert abs(result_mean.item() - 0.22342906892299652) < 1e-3
else:
@@ -96,7 +96,7 @@ def test_full_loop_with_v_prediction(self):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 124.77149200439453) < 1e-2
assert abs(result_mean.item() - 0.16226289014816284) < 1e-3
- elif torch_device in ["cuda"]:
+ elif torch_device in ["cuda", "xpu"]:
assert abs(result_sum.item() - 128.1663360595703) < 1e-2
assert abs(result_mean.item() - 0.16688326001167297) < 1e-3
else:
@@ -127,7 +127,7 @@ def test_full_loop_device(self):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 167.46957397460938) < 1e-2
assert abs(result_mean.item() - 0.21805934607982635) < 1e-3
- elif torch_device in ["cuda"]:
+ elif torch_device in ["cuda", "xpu"]:
assert abs(result_sum.item() - 171.59353637695312) < 1e-2
assert abs(result_mean.item() - 0.22342908382415771) < 1e-3
else:
@@ -159,9 +159,15 @@ def test_full_loop_device_karras_sigmas(self):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 176.66974135742188) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
- elif torch_device in ["cuda"]:
+ elif torch_device in ["cuda", "xpu"]:
assert abs(result_sum.item() - 177.63653564453125) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
else:
assert abs(result_sum.item() - 170.3135223388672) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py
index 873eaecd0a5c..7cbaa5cc5e8d 100644
--- a/tests/schedulers/test_scheduler_dpm_single.py
+++ b/tests/schedulers/test_scheduler_dpm_single.py
@@ -1,4 +1,5 @@
import tempfile
+import unittest
import torch
@@ -65,6 +66,7 @@ def check_over_configs(self, time_step=0, **config):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+ @unittest.skip("Test not supported.")
def test_from_save_pretrained(self):
pass
@@ -346,3 +348,9 @@ def test_custom_timesteps(self):
assert (
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
index b5522f5991f7..e97d64ec5f1d 100644
--- a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
+++ b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
@@ -3,9 +3,7 @@
import torch
-from diffusers import (
- EDMDPMSolverMultistepScheduler,
-)
+from diffusers import EDMDPMSolverMultistepScheduler
from .test_schedulers import SchedulerCommonTest
@@ -63,6 +61,7 @@ def check_over_configs(self, time_step=0, **config):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+ @unittest.skip("Test not supported.")
def test_from_save_pretrained(self):
pass
@@ -258,5 +257,6 @@ def test_duplicated_timesteps(self, **config):
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
assert len(scheduler.timesteps) == scheduler.num_inference_steps
+ @unittest.skip("Test not supported.")
def test_trained_betas(self):
pass
diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py
index fbb49b164165..4c7e02442cd0 100644
--- a/tests/schedulers/test_scheduler_euler.py
+++ b/tests/schedulers/test_scheduler_euler.py
@@ -263,3 +263,9 @@ def test_custom_sigmas(self):
assert (
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
), f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_flax.py b/tests/schedulers/test_scheduler_flax.py
index d2ee7e13146d..8ccb5f6594a5 100644
--- a/tests/schedulers/test_scheduler_flax.py
+++ b/tests/schedulers/test_scheduler_flax.py
@@ -338,8 +338,8 @@ def test_full_loop_no_noise(self):
assert abs(result_sum - 255.0714) < 1e-2
assert abs(result_mean - 0.332124) < 1e-3
else:
- assert abs(result_sum - 255.1113) < 1e-1
- assert abs(result_mean - 0.332176) < 1e-3
+ assert abs(result_sum - 270.2) < 1e-1
+ assert abs(result_mean - 0.3519494) < 1e-3
@require_flax
@@ -675,6 +675,7 @@ def check_over_configs(self, time_step=0, **config):
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+ @unittest.skip("Test not supported.")
def test_from_save_pretrained(self):
pass
diff --git a/tests/schedulers/test_scheduler_heun.py b/tests/schedulers/test_scheduler_heun.py
index a3689ef2ea63..9e060c6d476f 100644
--- a/tests/schedulers/test_scheduler_heun.py
+++ b/tests/schedulers/test_scheduler_heun.py
@@ -219,3 +219,9 @@ def test_custom_timesteps(self):
assert (
torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_ipndm.py b/tests/schedulers/test_scheduler_ipndm.py
index 87c8da3ee3c1..ac7973c58295 100644
--- a/tests/schedulers/test_scheduler_ipndm.py
+++ b/tests/schedulers/test_scheduler_ipndm.py
@@ -1,4 +1,5 @@
import tempfile
+import unittest
import torch
@@ -50,6 +51,7 @@ def check_over_configs(self, time_step=0, **config):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+ @unittest.skip("Test not supported.")
def test_from_save_pretrained(self):
pass
diff --git a/tests/schedulers/test_scheduler_kdpm2_ancestral.py b/tests/schedulers/test_scheduler_kdpm2_ancestral.py
index 82312629727c..fa85c2be45ed 100644
--- a/tests/schedulers/test_scheduler_kdpm2_ancestral.py
+++ b/tests/schedulers/test_scheduler_kdpm2_ancestral.py
@@ -156,3 +156,9 @@ def test_full_loop_with_noise(self):
assert abs(result_sum.item() - 93087.3437) < 1e-2, f" expected result sum 93087.3437, but get {result_sum}"
assert abs(result_mean.item() - 121.2074) < 5e-3, f" expected result mean 121.2074, but get {result_mean}"
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_kdpm2_discrete.py b/tests/schedulers/test_scheduler_kdpm2_discrete.py
index a992edcd9551..4d8923b6946b 100644
--- a/tests/schedulers/test_scheduler_kdpm2_discrete.py
+++ b/tests/schedulers/test_scheduler_kdpm2_discrete.py
@@ -164,3 +164,9 @@ def test_full_loop_with_noise(self):
assert abs(result_sum.item() - 70408.4062) < 1e-2, f" expected result sum 70408.4062, but get {result_sum}"
assert abs(result_mean.item() - 91.6776) < 1e-3, f" expected result mean 91.6776, but get {result_mean}"
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_lcm.py b/tests/schedulers/test_scheduler_lcm.py
index c2c6530faa11..f3f6e9ba5837 100644
--- a/tests/schedulers/test_scheduler_lcm.py
+++ b/tests/schedulers/test_scheduler_lcm.py
@@ -99,7 +99,7 @@ def test_add_noise_device(self, num_inference_steps=10):
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
- noise = torch.randn_like(scaled_sample).to(torch_device)
+ noise = torch.randn(scaled_sample.shape).to(torch_device)
t = scheduler.timesteps[5][None]
noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape)
diff --git a/tests/schedulers/test_scheduler_lms.py b/tests/schedulers/test_scheduler_lms.py
index 5c163ce9fe7a..3bfcd57c1b6d 100644
--- a/tests/schedulers/test_scheduler_lms.py
+++ b/tests/schedulers/test_scheduler_lms.py
@@ -168,3 +168,9 @@ def test_full_loop_with_noise(self):
assert abs(result_sum.item() - 27663.6895) < 1e-2
assert abs(result_mean.item() - 36.0204) < 1e-3
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_pndm.py b/tests/schedulers/test_scheduler_pndm.py
index c1519f7c7e8e..13c690468222 100644
--- a/tests/schedulers/test_scheduler_pndm.py
+++ b/tests/schedulers/test_scheduler_pndm.py
@@ -1,4 +1,5 @@
import tempfile
+import unittest
import torch
@@ -53,6 +54,7 @@ def check_over_configs(self, time_step=0, **config):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+ @unittest.skip("Test not supported.")
def test_from_save_pretrained(self):
pass
diff --git a/tests/schedulers/test_scheduler_sasolver.py b/tests/schedulers/test_scheduler_sasolver.py
index 574194632df0..baa2736b2fcc 100644
--- a/tests/schedulers/test_scheduler_sasolver.py
+++ b/tests/schedulers/test_scheduler_sasolver.py
@@ -103,8 +103,6 @@ def test_full_loop_no_noise(self):
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 329.1999816894531) < 1e-2
assert abs(result_mean.item() - 0.4286458194255829) < 1e-3
- else:
- print("None")
def test_full_loop_with_v_prediction(self):
scheduler_class = self.scheduler_classes[0]
@@ -135,8 +133,6 @@ def test_full_loop_with_v_prediction(self):
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 193.4154052734375) < 1e-2
assert abs(result_mean.item() - 0.2518429756164551) < 1e-3
- else:
- print("None")
def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0]
@@ -166,8 +162,6 @@ def test_full_loop_device(self):
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 337.394287109375) < 1e-2
assert abs(result_mean.item() - 0.4393154978752136) < 1e-3
- else:
- print("None")
def test_full_loop_device_karras_sigmas(self):
scheduler_class = self.scheduler_classes[0]
@@ -198,5 +192,9 @@ def test_full_loop_device_karras_sigmas(self):
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 837.25537109375) < 1e-2
assert abs(result_mean.item() - 1.0901763439178467) < 1e-2
- else:
- print("None")
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_unclip.py b/tests/schedulers/test_scheduler_unclip.py
index b0ce1312e79f..9e66a328f42e 100644
--- a/tests/schedulers/test_scheduler_unclip.py
+++ b/tests/schedulers/test_scheduler_unclip.py
@@ -1,3 +1,5 @@
+import unittest
+
import torch
from diffusers import UnCLIPScheduler
@@ -130,8 +132,10 @@ def test_full_loop_skip_timesteps(self):
assert abs(result_sum.item() - 258.2044983) < 1e-2
assert abs(result_mean.item() - 0.3362038) < 1e-3
+ @unittest.skip("Test not supported.")
def test_trained_betas(self):
pass
+ @unittest.skip("Test not supported.")
def test_add_noise_device(self):
pass
diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py
index 5eb4d5ceef01..197c831cb015 100644
--- a/tests/schedulers/test_scheduler_unipc.py
+++ b/tests/schedulers/test_scheduler_unipc.py
@@ -393,3 +393,9 @@ def test_full_loop_with_noise(self):
assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}"
assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}"
+
+ def test_beta_sigmas(self):
+ self.check_over_configs(use_beta_sigmas=True)
+
+ def test_exponential_sigmas(self):
+ self.check_over_configs(use_exponential_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_vq_diffusion.py b/tests/schedulers/test_scheduler_vq_diffusion.py
index 74437ad45480..c12825ba2e62 100644
--- a/tests/schedulers/test_scheduler_vq_diffusion.py
+++ b/tests/schedulers/test_scheduler_vq_diffusion.py
@@ -1,3 +1,5 @@
+import unittest
+
import torch
import torch.nn.functional as F
@@ -52,5 +54,6 @@ def test_time_indices(self):
for t in [0, 50, 99]:
self.check_over_forward(time_step=t)
+ @unittest.skip("Test not supported.")
def test_add_noise_device(self):
pass
diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py
index fc7f22d2a8e5..42ca1bc54155 100755
--- a/tests/schedulers/test_schedulers.py
+++ b/tests/schedulers/test_schedulers.py
@@ -361,7 +361,7 @@ def model(sample, t, *args):
if isinstance(t, torch.Tensor):
num_dims = len(sample.shape)
# pad t with 1s to match num_dims
- t = t.reshape(-1, *(1,) * (num_dims - 1)).to(sample.device).to(sample.dtype)
+ t = t.reshape(-1, *(1,) * (num_dims - 1)).to(sample.device, dtype=sample.dtype)
return sample * t / (t + 1)
@@ -722,7 +722,7 @@ def test_add_noise_device(self):
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
- noise = torch.randn_like(scaled_sample).to(torch_device)
+ noise = torch.randn(scaled_sample.shape).to(torch_device)
t = scheduler.timesteps[5][None]
noised = scheduler.add_noise(scaled_sample, noise, t)
self.assertEqual(noised.shape, scaled_sample.shape)
diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py
index 9b89578c5a8c..4e7bc0af6842 100644
--- a/tests/single_file/single_file_testing_utils.py
+++ b/tests/single_file/single_file_testing_utils.py
@@ -47,6 +47,8 @@ def download_diffusers_config(repo_id, tmpdir):
class SDSingleFileTesterMixin:
+ single_file_kwargs = {}
+
def _compare_component_configs(self, pipe, single_file_pipe):
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
@@ -154,23 +156,23 @@ def test_single_file_components_with_original_config_local_files_only(
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
- sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None)
+ sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None, **self.single_file_kwargs)
sf_pipe.unet.set_attn_processor(AttnProcessor())
- sf_pipe.enable_model_cpu_offload()
+ sf_pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device)
image_single_file = sf_pipe(**inputs).images[0]
pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
pipe.unet.set_attn_processor(AttnProcessor())
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
- assert max_diff < expected_max_diff
+ assert max_diff < expected_max_diff, f"{image.flatten()} != {image_single_file.flatten()}"
def test_single_file_components_with_diffusers_config(
self,
@@ -378,14 +380,14 @@ def test_single_file_components_with_diffusers_config_local_files_only(
def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, torch_dtype=torch.float16, safety_checker=None)
sf_pipe.unet.set_default_attn_processor()
- sf_pipe.enable_model_cpu_offload()
+ sf_pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device)
image_single_file = sf_pipe(**inputs).images[0]
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16, safety_checker=None)
pipe.unet.set_default_attn_processor()
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
diff --git a/tests/single_file/test_lumina2_transformer.py b/tests/single_file/test_lumina2_transformer.py
new file mode 100644
index 000000000000..78e68c4c2df0
--- /dev/null
+++ b/tests/single_file/test_lumina2_transformer.py
@@ -0,0 +1,74 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+
+from diffusers import (
+ Lumina2Transformer2DModel,
+)
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ torch_device,
+)
+
+
+enable_full_determinism()
+
+
+@require_torch_accelerator
+class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
+ model_class = Lumina2Transformer2DModel
+ ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
+ alternate_keys_ckpt_paths = [
+ "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
+ ]
+
+ repo_id = "Alpha-VLLM/Lumina-Image-2.0"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_single_file_components(self):
+ model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
+ model_single_file = self.model_class.from_single_file(self.ckpt_path)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
+
+ def test_checkpoint_loading(self):
+ for ckpt_path in self.alternate_keys_ckpt_paths:
+ torch.cuda.empty_cache()
+ model = self.model_class.from_single_file(ckpt_path)
+
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
diff --git a/tests/single_file/test_model_autoencoder_dc_single_file.py b/tests/single_file/test_model_autoencoder_dc_single_file.py
new file mode 100644
index 000000000000..b1faeb78776b
--- /dev/null
+++ b/tests/single_file/test_model_autoencoder_dc_single_file.py
@@ -0,0 +1,126 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+
+from diffusers import (
+ AutoencoderDC,
+)
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ load_hf_numpy,
+ numpy_cosine_similarity_distance,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+
+
+enable_full_determinism()
+
+
+@slow
+@require_torch_accelerator
+class AutoencoderDCSingleFileTests(unittest.TestCase):
+ model_class = AutoencoderDC
+ ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
+ repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_file_format(self, seed, shape):
+ return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
+
+ def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
+ dtype = torch.float16 if fp16 else torch.float32
+ image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
+ return image
+
+ def test_single_file_inference_same_as_pretrained(self):
+ model_1 = self.model_class.from_pretrained(self.repo_id).to(torch_device)
+ model_2 = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id).to(torch_device)
+
+ image = self.get_sd_image(33)
+
+ with torch.no_grad():
+ sample_1 = model_1(image).sample
+ sample_2 = model_2(image).sample
+
+ assert sample_1.shape == sample_2.shape
+
+ output_slice_1 = sample_1.flatten().float().cpu()
+ output_slice_2 = sample_2.flatten().float().cpu()
+
+ assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
+
+ def test_single_file_components(self):
+ model = self.model_class.from_pretrained(self.repo_id)
+ model_single_file = self.model_class.from_single_file(self.ckpt_path)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between pretrained loading and single file loading"
+
+ def test_single_file_in_type_variant_components(self):
+ # `in` variant checkpoints require passing in a `config` parameter
+ # in order to set the scaling factor correctly.
+ # `in` and `mix` variants have the same keys and we cannot automatically infer a scaling factor.
+ # We default to using teh `mix` config
+ repo_id = "mit-han-lab/dc-ae-f128c512-in-1.0-diffusers"
+ ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0/blob/main/model.safetensors"
+
+ model = self.model_class.from_pretrained(repo_id)
+ model_single_file = self.model_class.from_single_file(ckpt_path, config=repo_id)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between pretrained loading and single file loading"
+
+ def test_single_file_mix_type_variant_components(self):
+ repo_id = "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"
+ ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0/blob/main/model.safetensors"
+
+ model = self.model_class.from_pretrained(repo_id)
+ model_single_file = self.model_class.from_single_file(ckpt_path, config=repo_id)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between pretrained loading and single file loading"
diff --git a/tests/single_file/test_model_controlnet_single_file.py b/tests/single_file/test_model_controlnet_single_file.py
index 1d5b790ebb4a..bfcb802380a6 100644
--- a/tests/single_file/test_model_controlnet_single_file.py
+++ b/tests/single_file/test_model_controlnet_single_file.py
@@ -22,9 +22,11 @@
ControlNetModel,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
@@ -32,7 +34,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class ControlNetModelSingleFileTests(unittest.TestCase):
model_class = ControlNetModel
ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
@@ -41,12 +43,12 @@ class ControlNetModelSingleFileTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id)
diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py
new file mode 100644
index 000000000000..0ec97db26a9e
--- /dev/null
+++ b/tests/single_file/test_model_flux_transformer_single_file.py
@@ -0,0 +1,72 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+
+from diffusers import (
+ FluxTransformer2DModel,
+)
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ torch_device,
+)
+
+
+enable_full_determinism()
+
+
+@require_torch_accelerator
+class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
+ model_class = FluxTransformer2DModel
+ ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
+ alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
+
+ repo_id = "black-forest-labs/FLUX.1-dev"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_single_file_components(self):
+ model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
+ model_single_file = self.model_class.from_single_file(self.ckpt_path)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
+
+ def test_checkpoint_loading(self):
+ for ckpt_path in self.alternate_keys_ckpt_paths:
+ torch.cuda.empty_cache()
+ model = self.model_class.from_single_file(ckpt_path)
+
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
diff --git a/tests/single_file/test_model_sd_cascade_unet_single_file.py b/tests/single_file/test_model_sd_cascade_unet_single_file.py
index 43253eb6d59f..08b04e3cd7e8 100644
--- a/tests/single_file/test_model_sd_cascade_unet_single_file.py
+++ b/tests/single_file/test_model_sd_cascade_unet_single_file.py
@@ -21,9 +21,11 @@
from diffusers import StableCascadeUNet
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
@@ -33,17 +35,17 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableCascadeUNetSingleFileTest(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_single_file_components_stage_b(self):
model_single_file = StableCascadeUNet.from_single_file(
diff --git a/tests/single_file/test_model_vae_single_file.py b/tests/single_file/test_model_vae_single_file.py
index 63f2bb757472..9db4cddb3c9d 100644
--- a/tests/single_file/test_model_vae_single_file.py
+++ b/tests/single_file/test_model_vae_single_file.py
@@ -22,10 +22,11 @@
AutoencoderKL,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
load_hf_numpy,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -35,7 +36,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class AutoencoderKLSingleFileTests(unittest.TestCase):
model_class = AutoencoderKL
ckpt_path = (
@@ -48,12 +49,12 @@ class AutoencoderKLSingleFileTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
diff --git a/tests/single_file/test_model_wan_autoencoder_single_file.py b/tests/single_file/test_model_wan_autoencoder_single_file.py
new file mode 100644
index 000000000000..f5720ddd3964
--- /dev/null
+++ b/tests/single_file/test_model_wan_autoencoder_single_file.py
@@ -0,0 +1,61 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+from diffusers import (
+ AutoencoderKLWan,
+)
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ torch_device,
+)
+
+
+enable_full_determinism()
+
+
+@require_torch_accelerator
+class AutoencoderKLWanSingleFileTests(unittest.TestCase):
+ model_class = AutoencoderKLWan
+ ckpt_path = (
+ "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
+ )
+ repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_single_file_components(self):
+ model = self.model_class.from_pretrained(self.repo_id, subfolder="vae")
+ model_single_file = self.model_class.from_single_file(self.ckpt_path)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
diff --git a/tests/single_file/test_model_wan_transformer3d_single_file.py b/tests/single_file/test_model_wan_transformer3d_single_file.py
new file mode 100644
index 000000000000..9b938aa1754c
--- /dev/null
+++ b/tests/single_file/test_model_wan_transformer3d_single_file.py
@@ -0,0 +1,93 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+
+from diffusers import (
+ WanTransformer3DModel,
+)
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_big_gpu_with_torch_cuda,
+ require_torch_accelerator,
+ torch_device,
+)
+
+
+enable_full_determinism()
+
+
+@require_torch_accelerator
+class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
+ model_class = WanTransformer3DModel
+ ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
+ repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_single_file_components(self):
+ model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
+ model_single_file = self.model_class.from_single_file(self.ckpt_path)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
+
+
+@require_big_gpu_with_torch_cuda
+@require_torch_accelerator
+class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
+ model_class = WanTransformer3DModel
+ ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors"
+ repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+ torch_dtype = torch.float8_e4m3fn
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_single_file_components(self):
+ model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer", torch_dtype=self.torch_dtype)
+ model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=self.torch_dtype)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py
new file mode 100644
index 000000000000..7695e1577711
--- /dev/null
+++ b/tests/single_file/test_sana_transformer.py
@@ -0,0 +1,61 @@
+import gc
+import unittest
+
+import torch
+
+from diffusers import (
+ SanaTransformer2DModel,
+)
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ torch_device,
+)
+
+
+enable_full_determinism()
+
+
+@require_torch_accelerator
+class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
+ model_class = SanaTransformer2DModel
+ ckpt_path = (
+ "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
+ )
+ alternate_keys_ckpt_paths = [
+ "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
+ ]
+
+ repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_single_file_components(self):
+ model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
+ model_single_file = self.model_class.from_single_file(self.ckpt_path)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert (
+ model.config[param_name] == param_value
+ ), f"{param_name} differs between single file loading and pretrained loading"
+
+ def test_checkpoint_loading(self):
+ for ckpt_path in self.alternate_keys_ckpt_paths:
+ torch.cuda.empty_cache()
+ model = self.model_class.from_single_file(ckpt_path)
+
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
index 332bcfbe03b6..7589b48028c2 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
@@ -8,9 +8,10 @@
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -27,7 +28,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = (
@@ -41,12 +42,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -75,14 +76,14 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
pipe.unet.set_default_attn_processor()
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe_sf = self.pipeline_class.from_single_file(
self.ckpt_path,
controlnet=controlnet,
)
pipe_sf.unet.set_default_attn_processor()
- pipe_sf.enable_model_cpu_offload()
+ pipe_sf.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device)
output = pipe(**inputs).images[0]
diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
index c0d70123b286..1555831db6db 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
@@ -8,10 +8,12 @@
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from .single_file_testing_utils import (
@@ -26,7 +28,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
@@ -36,12 +38,12 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self):
control_image = load_image(
@@ -71,11 +73,11 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet, safety_checker=None)
pipe.unet.set_default_attn_processor()
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe_sf = self.pipeline_class.from_single_file(self.ckpt_path, controlnet=controlnet, safety_checker=None)
pipe_sf.unet.set_default_attn_processor()
- pipe_sf.enable_model_cpu_offload()
+ pipe_sf.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs()
output = pipe(**inputs).images[0]
diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
index 3b5cf910b080..2c1e414e5e36 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
@@ -8,10 +8,12 @@
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from .single_file_testing_utils import (
@@ -26,7 +28,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = (
@@ -40,12 +42,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self):
control_image = load_image(
@@ -65,14 +67,14 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
pipe.unet.set_default_attn_processor()
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
pipe_sf = self.pipeline_class.from_single_file(
self.ckpt_path,
controlnet=controlnet,
)
pipe_sf.unet.set_default_attn_processor()
- pipe_sf.enable_model_cpu_offload()
+ pipe_sf.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs()
output = pipe(**inputs).images[0]
diff --git a/tests/single_file/test_stable_diffusion_img2img_single_file.py b/tests/single_file/test_stable_diffusion_img2img_single_file.py
index 04f36f255014..9ad935582409 100644
--- a/tests/single_file/test_stable_diffusion_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_img2img_single_file.py
@@ -8,9 +8,11 @@
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -20,7 +22,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = (
@@ -34,12 +36,12 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -63,7 +65,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusion21Img2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
@@ -73,12 +75,12 @@ class StableDiffusion21Img2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDS
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
diff --git a/tests/single_file/test_stable_diffusion_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_inpaint_single_file.py
index 5c6734a9a33e..b05a098c0bcb 100644
--- a/tests/single_file/test_stable_diffusion_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_inpaint_single_file.py
@@ -8,9 +8,11 @@
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -20,7 +22,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
@@ -30,12 +32,12 @@ class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSin
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -78,7 +80,7 @@ def test_single_file_components_with_original_config_local_files_only(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = (
@@ -90,12 +92,12 @@ class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDS
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py
index e46e87e18c18..78baeb94929c 100644
--- a/tests/single_file/test_stable_diffusion_single_file.py
+++ b/tests/single_file/test_stable_diffusion_single_file.py
@@ -4,12 +4,16 @@
import torch
-from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
+from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ nightly,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from .single_file_testing_utils import (
@@ -23,7 +27,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline
ckpt_path = (
@@ -37,12 +41,12 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -95,12 +99,12 @@ class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFi
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -116,3 +120,45 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0
def test_single_file_format_inference_is_same_as_pretrained(self):
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
+
+
+@nightly
+@slow
+@require_torch_accelerator
+class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+ pipeline_class = StableDiffusionInstructPix2PixPipeline
+ ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors"
+ original_config = (
+ "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/refs/heads/main/configs/generate.yaml"
+ )
+ repo_id = "timbrooks/instruct-pix2pix"
+ single_file_kwargs = {"extract_ema": True}
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
+ generator = torch.Generator(device=generator_device).manual_seed(seed)
+ image = load_image(
+ "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_pix2pix/example.jpg"
+ )
+ inputs = {
+ "prompt": "turn him into a cyborg",
+ "image": image,
+ "generator": generator,
+ "num_inference_steps": 3,
+ "guidance_scale": 7.5,
+ "image_guidance_scale": 1.0,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_single_file_format_inference_is_same_as_pretrained(self):
+ super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
diff --git a/tests/single_file/test_stable_diffusion_upscale_single_file.py b/tests/single_file/test_stable_diffusion_upscale_single_file.py
index 3c26d001c2b0..398fc9ece359 100644
--- a/tests/single_file/test_stable_diffusion_upscale_single_file.py
+++ b/tests/single_file/test_stable_diffusion_upscale_single_file.py
@@ -1,6 +1,7 @@
import gc
import unittest
+import pytest
import torch
from diffusers import (
@@ -8,10 +9,12 @@
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -21,7 +24,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionUpscalePipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionUpscalePipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
@@ -31,12 +34,12 @@ class StableDiffusionUpscalePipelineSingleFileSlowTests(unittest.TestCase, SDSin
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_single_file_format_inference_is_same_as_pretrained(self):
image = load_image(
@@ -46,14 +49,14 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
prompt = "a cat sitting on a park bench"
pipe = StableDiffusionUpscalePipeline.from_pretrained(self.repo_id)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
generator = torch.Generator("cpu").manual_seed(0)
output = pipe(prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3)
image_from_pretrained = output.images[0]
pipe_from_single_file = StableDiffusionUpscalePipeline.from_single_file(self.ckpt_path)
- pipe_from_single_file.enable_model_cpu_offload()
+ pipe_from_single_file.enable_model_cpu_offload(device=torch_device)
generator = torch.Generator("cpu").manual_seed(0)
output_from_single_file = pipe_from_single_file(
@@ -66,3 +69,19 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
assert (
numpy_cosine_similarity_distance(image_from_pretrained.flatten(), image_from_single_file.flatten()) < 1e-3
)
+
+ @pytest.mark.xfail(
+ condition=True,
+ reason="Test fails because of mismatches in the configs but it is very hard to properly fix this considering downstream usecase.",
+ strict=True,
+ )
+ def test_single_file_components_with_original_config(self):
+ super().test_single_file_components_with_original_config()
+
+ @pytest.mark.xfail(
+ condition=True,
+ reason="Test fails because of mismatches in the configs but it is very hard to properly fix this considering downstream usecase.",
+ strict=True,
+ )
+ def test_single_file_components_with_original_config_local_files_only(self):
+ super().test_single_file_components_with_original_config_local_files_only()
diff --git a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
index ead77a1d6553..fb5f8725b86e 100644
--- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
@@ -11,10 +11,12 @@
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from .single_file_testing_utils import (
@@ -29,7 +31,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLAdapterPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
@@ -41,12 +43,12 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self):
prompt = "toy"
@@ -74,7 +76,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
torch_dtype=torch.float16,
safety_checker=None,
)
- pipe_single_file.enable_model_cpu_offload()
+ pipe_single_file.enable_model_cpu_offload(device=torch_device)
pipe_single_file.set_progress_bar_config(disable=None)
inputs = self.get_inputs()
@@ -86,7 +88,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
torch_dtype=torch.float16,
safety_checker=None,
)
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs()
images = pipe(**inputs).images[0]
diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
index 9491adf2dfa4..6d8c4369e1e1 100644
--- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
@@ -8,9 +8,10 @@
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -26,7 +27,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLControlNetPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
@@ -38,12 +39,12 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -68,7 +69,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
self.ckpt_path, controlnet=controlnet, torch_dtype=torch.float16
)
pipe_single_file.unet.set_default_attn_processor()
- pipe_single_file.enable_model_cpu_offload()
+ pipe_single_file.enable_model_cpu_offload(device=torch_device)
pipe_single_file.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
@@ -76,7 +77,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet, torch_dtype=torch.float16)
pipe.unet.set_default_attn_processor()
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_inputs(torch_device)
images = pipe(**inputs).images[0]
diff --git a/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py b/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
index 71b57eb7c6c9..7df8b84bc235 100644
--- a/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
@@ -9,10 +9,12 @@
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from .single_file_testing_utils import SDXLSingleFileTesterMixin
@@ -22,7 +24,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
@@ -34,12 +36,12 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -63,7 +65,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests(unittest.TestCase):
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = (
@@ -83,7 +85,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_default_attn_processor()
- pipe.enable_model_cpu_offload()
+ pipe.enable_model_cpu_offload(device=torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(
@@ -93,7 +95,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
pipe_single_file = self.pipeline_class.from_single_file(self.ckpt_path, torch_dtype=torch.float16)
pipe_single_file.scheduler = DDIMScheduler.from_config(pipe_single_file.scheduler.config)
pipe_single_file.unet.set_default_attn_processor()
- pipe_single_file.enable_model_cpu_offload()
+ pipe_single_file.enable_model_cpu_offload(device=torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
image_single_file = pipe_single_file(
diff --git a/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py b/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
index 7ebddc8555bb..5a014638633b 100644
--- a/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
+++ b/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
@@ -5,9 +5,11 @@
from diffusers import StableDiffusionXLInstructPix2PixPipeline
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
@@ -15,7 +17,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionXLInstructPix2PixPipeline(unittest.TestCase):
pipeline_class = StableDiffusionXLInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
@@ -25,12 +27,12 @@ class StableDiffusionXLInstructPix2PixPipeline(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
diff --git a/tests/single_file/test_stable_diffusion_xl_single_file.py b/tests/single_file/test_stable_diffusion_xl_single_file.py
index a143a35a2bbc..77f58d859209 100644
--- a/tests/single_file/test_stable_diffusion_xl_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_single_file.py
@@ -7,9 +7,11 @@
StableDiffusionXLPipeline,
)
from diffusers.utils.testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
+ torch_device,
)
from .single_file_testing_utils import SDXLSingleFileTesterMixin
@@ -19,7 +21,7 @@
@slow
-@require_torch_gpu
+@require_torch_accelerator
class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
@@ -31,12 +33,12 @@ class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingle
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py
index 626a9a468572..d39fe6a618d4 100644
--- a/utils/check_config_docstrings.py
+++ b/utils/check_config_docstrings.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/check_copies.py b/utils/check_copies.py
index 20449e790db2..001366c1905f 100644
--- a/utils/check_copies.py
+++ b/utils/check_copies.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/check_doc_toc.py b/utils/check_doc_toc.py
index 35ded936650d..d7c9cee82fcb 100644
--- a/utils/check_doc_toc.py
+++ b/utils/check_doc_toc.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/check_dummies.py b/utils/check_dummies.py
index af99eeb05c6d..04a670c2f5d9 100644
--- a/utils/check_dummies.py
+++ b/utils/check_dummies.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/check_inits.py b/utils/check_inits.py
index 2c514046afaa..8208fa634186 100644
--- a/utils/check_inits.py
+++ b/utils/check_inits.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 597893f267ca..14bdbe60adf0 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/check_support_list.py b/utils/check_support_list.py
new file mode 100644
index 000000000000..89cfce62de0b
--- /dev/null
+++ b/utils/check_support_list.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+"""
+Utility that checks that modules like attention processors are listed in the documentation file.
+
+```bash
+python utils/check_support_list.py
+```
+
+It has no auto-fix mode.
+"""
+
+import os
+import re
+
+
+# All paths are set with the intent that you run this script from the root of the repo
+REPO_PATH = "."
+
+
+def read_documented_classes(doc_path, autodoc_regex=r"\[\[autodoc\]\]\s([^\n]+)"):
+ """
+ Reads documented classes from a doc file using a regex to find lines like [[autodoc]] my.module.Class.
+ Returns a list of documented class names (just the class name portion).
+ """
+ with open(os.path.join(REPO_PATH, doc_path), "r") as f:
+ doctext = f.read()
+ matches = re.findall(autodoc_regex, doctext)
+ return [match.split(".")[-1] for match in matches]
+
+
+def read_source_classes(src_path, class_regex, exclude_conditions=None):
+ """
+ Reads class names from a source file using a regex that captures class definitions.
+ Optionally exclude classes based on a list of conditions (functions that take class name and return bool).
+ """
+ if exclude_conditions is None:
+ exclude_conditions = []
+ with open(os.path.join(REPO_PATH, src_path), "r") as f:
+ doctext = f.read()
+ classes = re.findall(class_regex, doctext)
+ # Filter out classes that meet any of the exclude conditions
+ filtered_classes = [c for c in classes if not any(cond(c) for cond in exclude_conditions)]
+ return filtered_classes
+
+
+def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_conditions=None):
+ """
+ Generic function to check if all classes defined in `src_path` are documented in `doc_path`.
+ Returns a set of undocumented class names.
+ """
+ documented = set(read_documented_classes(doc_path, doc_regex))
+ source_classes = set(read_source_classes(src_path, src_regex, exclude_conditions=exclude_conditions))
+
+ # Find which classes in source are not documented in a deterministic way.
+ undocumented = sorted(source_classes - documented)
+ return undocumented
+
+
+if __name__ == "__main__":
+ # Define the checks we need to perform
+ checks = {
+ "Attention Processors": {
+ "doc_path": "docs/source/en/api/attnprocessor.md",
+ "src_path": "src/diffusers/models/attention_processor.py",
+ "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
+ "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
+ "exclude_conditions": [lambda c: "LoRA" in c, lambda c: c == "Attention"],
+ },
+ "Image Processors": {
+ "doc_path": "docs/source/en/api/image_processor.md",
+ "src_path": "src/diffusers/image_processor.py",
+ "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
+ "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
+ },
+ "Activations": {
+ "doc_path": "docs/source/en/api/activations.md",
+ "src_path": "src/diffusers/models/activations.py",
+ "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
+ "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
+ },
+ "Normalizations": {
+ "doc_path": "docs/source/en/api/normalization.md",
+ "src_path": "src/diffusers/models/normalization.py",
+ "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
+ "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
+ "exclude_conditions": [
+ # Exclude LayerNorm as it's an intentional exception
+ lambda c: c == "LayerNorm"
+ ],
+ },
+ "LoRA Mixins": {
+ "doc_path": "docs/source/en/api/loaders/lora.md",
+ "src_path": "src/diffusers/loaders/lora_pipeline.py",
+ "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
+ "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
+ },
+ }
+
+ missing_items = {}
+ for category, params in checks.items():
+ undocumented = check_documentation(
+ doc_path=params["doc_path"],
+ src_path=params["src_path"],
+ doc_regex=params["doc_regex"],
+ src_regex=params["src_regex"],
+ exclude_conditions=params.get("exclude_conditions"),
+ )
+ if undocumented:
+ missing_items[category] = undocumented
+
+ # If we have any missing items, raise a single combined error
+ if missing_items:
+ error_msg = ["Some classes are not documented properly:\n"]
+ for category, classes in missing_items.items():
+ error_msg.append(f"- {category}: {', '.join(sorted(classes))}")
+ raise ValueError("\n".join(error_msg))
diff --git a/utils/check_table.py b/utils/check_table.py
index 80fd5660bb46..83c29aa74eca 100644
--- a/utils/check_table.py
+++ b/utils/check_table.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py
index 6c2bb7f5d69c..791df0e78694 100644
--- a/utils/custom_init_isort.py
+++ b/utils/custom_init_isort.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/extract_tests_from_mixin.py b/utils/extract_tests_from_mixin.py
new file mode 100644
index 000000000000..c8b65b96ee16
--- /dev/null
+++ b/utils/extract_tests_from_mixin.py
@@ -0,0 +1,61 @@
+import argparse
+import inspect
+import sys
+from pathlib import Path
+from typing import List, Type
+
+
+root_dir = Path(__file__).parent.parent.absolute()
+sys.path.insert(0, str(root_dir))
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--type", type=str, default=None)
+args = parser.parse_args()
+
+
+def get_test_methods_from_class(cls: Type) -> List[str]:
+ """
+ Get all test method names from a given class.
+ Only returns methods that start with 'test_'.
+ """
+ test_methods = []
+ for name, obj in inspect.getmembers(cls):
+ if name.startswith("test_") and inspect.isfunction(obj):
+ test_methods.append(name)
+ return sorted(test_methods)
+
+
+def generate_pytest_pattern(test_methods: List[str]) -> str:
+ """Generate pytest pattern string for the -k flag."""
+ return " or ".join(test_methods)
+
+
+def generate_pattern_for_mixin(mixin_class: Type) -> str:
+ """
+ Generate pytest pattern for a specific mixin class.
+ """
+ if mixin_cls is None:
+ return ""
+ test_methods = get_test_methods_from_class(mixin_class)
+ return generate_pytest_pattern(test_methods)
+
+
+if __name__ == "__main__":
+ mixin_cls = None
+ if args.type == "pipeline":
+ from tests.pipelines.test_pipelines_common import PipelineTesterMixin
+
+ mixin_cls = PipelineTesterMixin
+
+ elif args.type == "models":
+ from tests.models.test_modeling_common import ModelTesterMixin
+
+ mixin_cls = ModelTesterMixin
+
+ elif args.type == "lora":
+ from tests.lora.utils import PeftLoraLoaderMixinTests
+
+ mixin_cls = PeftLoraLoaderMixinTests
+
+ pattern = generate_pattern_for_mixin(mixin_cls)
+ print(pattern)
diff --git a/utils/fetch_torch_cuda_pipeline_test_matrix.py b/utils/fetch_torch_cuda_pipeline_test_matrix.py
index e6a9c4b6a3bd..196f35628ac1 100644
--- a/utils/fetch_torch_cuda_pipeline_test_matrix.py
+++ b/utils/fetch_torch_cuda_pipeline_test_matrix.py
@@ -12,16 +12,14 @@
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
ALWAYS_TEST_PIPELINE_MODULES = [
"controlnet",
+ "controlnet_flux",
+ "controlnet_sd3",
"stable_diffusion",
"stable_diffusion_2",
+ "stable_diffusion_3",
"stable_diffusion_xl",
- "stable_diffusion_adapter",
- "deepfloyd_if",
"ip_adapters",
- "kandinsky",
- "kandinsky2_2",
- "text_to_video_synthesis",
- "wuerstchen",
+ "flux",
]
PIPELINE_USAGE_CUTOFF = int(os.getenv("PIPELINE_USAGE_CUTOFF", 50000))
diff --git a/utils/get_modified_files.py b/utils/get_modified_files.py
index a252bc648be5..e392e50c12d3 100644
--- a/utils/get_modified_files.py
+++ b/utils/get_modified_files.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/overwrite_expected_slice.py b/utils/overwrite_expected_slice.py
index 07778a05b1ee..723c1c98fc21 100644
--- a/utils/overwrite_expected_slice.py
+++ b/utils/overwrite_expected_slice.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/print_env.py b/utils/print_env.py
index 3e4495c98094..0a1cfbef133f 100644
--- a/utils/print_env.py
+++ b/utils/print_env.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -37,6 +37,10 @@
print("Cuda version:", torch.version.cuda)
print("CuDNN version:", torch.backends.cudnn.version())
print("Number of GPUs available:", torch.cuda.device_count())
+ if torch.cuda.is_available():
+ device_properties = torch.cuda.get_device_properties(0)
+ total_memory = device_properties.total_memory / (1024**3)
+ print(f"CUDA memory: {total_memory} GB")
except ImportError:
print("Torch version:", None)
diff --git a/utils/stale.py b/utils/stale.py
index c01b6d5682e9..20cb6cabeb91 100644
--- a/utils/stale.py
+++ b/utils/stale.py
@@ -24,6 +24,7 @@
LABELS_TO_EXEMPT = [
+ "close-to-merge",
"good first issue",
"good second issue",
"good difficult issue",
diff --git a/utils/update_metadata.py b/utils/update_metadata.py
index 103a2b9ab0cc..a97e65801c5f 100644
--- a/utils/update_metadata.py
+++ b/utils/update_metadata.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.